prettified all the code using "blue" at the urging of @tildebyte

This commit is contained in:
Lincoln Stein 2022-08-26 03:15:42 -04:00
parent dd670200bb
commit 4f02b72c9c
35 changed files with 6252 additions and 3119 deletions

View File

@ -1,11 +1,17 @@
from abc import abstractmethod from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset from torch.utils.data import (
Dataset,
ConcatDataset,
ChainDataset,
IterableDataset,
)
class Txt2ImgIterableBaseDataset(IterableDataset): class Txt2ImgIterableBaseDataset(IterableDataset):
''' """
Define an interface to make the IterableDatasets for text2img data chainable Define an interface to make the IterableDatasets for text2img data chainable
''' """
def __init__(self, num_records=0, valid_ids=None, size=256): def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__() super().__init__()
self.num_records = num_records self.num_records = num_records
@ -13,7 +19,9 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
self.sample_ids = valid_ids self.sample_ids = valid_ids
self.size = size self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') print(
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
def __len__(self):
    """Return the record count this dataset was constructed with."""
    record_count = self.num_records
    return record_count

View File

@ -11,13 +11,21 @@ from tqdm import tqdm
from torch.utils.data import Dataset, Subset from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve from taming.data.imagenet import (
str_to_indices,
give_synsets_from_indices,
download,
retrieve,
)
from taming.data.imagenet import ImagePaths from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light from ldm.modules.image_degradation import (
degradation_fn_bsr,
degradation_fn_bsr_light,
)
def synset2idx(path_to_yaml='data/index_synset.yaml'):
    """Return the inverse of the mapping stored in ``path_to_yaml``.

    The YAML document is read as a mapping and its ``key -> value`` pairs
    are flipped to ``value -> key``. If several keys share one value, the
    pair that comes last in iteration order wins.
    """
    with open(path_to_yaml) as f:
        # yaml.load() without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError from PyYAML 6.0 on; safe_load
        # works on every version and never constructs arbitrary objects.
        di2s = yaml.safe_load(f)
    return {v: k for k, v in di2s.items()}
@ -28,7 +36,9 @@ class ImageNetBase(Dataset):
self.config = config or OmegaConf.create() self.config = config or OmegaConf.create()
if not type(self.config) == dict: if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config) self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) self.keep_orig_class_label = self.config.get(
'keep_orig_class_label', False
)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare() self._prepare()
self._prepare_synset_to_human() self._prepare_synset_to_human()
@ -46,17 +56,23 @@ class ImageNetBase(Dataset):
raise NotImplementedError() raise NotImplementedError()
def _filter_relpaths(self, relpaths): def _filter_relpaths(self, relpaths):
ignore = set([ ignore = set(
"n06596364_9591.JPEG", [
]) 'n06596364_9591.JPEG',
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] ]
if "sub_indices" in self.config: )
indices = str_to_indices(self.config["sub_indices"]) relpaths = [
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore
]
if 'sub_indices' in self.config:
indices = str_to_indices(self.config['sub_indices'])
synsets = give_synsets_from_indices(
indices, path_to_yaml=self.idx2syn
) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = [] files = []
for rpath in relpaths: for rpath in relpaths:
syn = rpath.split("/")[0] syn = rpath.split('/')[0]
if syn in synsets: if syn in synsets:
files.append(rpath) files.append(rpath)
return files return files
@ -65,64 +81,75 @@ class ImageNetBase(Dataset):
def _prepare_synset_to_human(self):
    """Make sure ``synset_human.txt`` is present under ``self.root``.

    Stores the file's path in ``self.human_dict`` and calls ``download``
    when the file is missing or its size differs from the expected one.
    """
    expected_size = 2655750
    url = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
    target = os.path.join(self.root, 'synset_human.txt')
    self.human_dict = target
    # Short-circuiting keeps getsize() from running on a missing file.
    have_valid_copy = (
        os.path.exists(target) and os.path.getsize(target) == expected_size
    )
    if not have_valid_copy:
        download(url, target)
def _prepare_idx_to_synset(self):
    """Make sure ``index_synset.yaml`` is present under ``self.root``.

    Stores the file's path in ``self.idx2syn``; ``download`` is called
    only when nothing exists at that path yet.
    """
    url = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
    self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
    if os.path.exists(self.idx2syn):
        return
    download(url, self.idx2syn)
def _prepare_human_to_integer_label(self):
    """Load the label-text -> integer mapping into ``self.human2integer_dict``.

    Sets ``self.human2integer`` to the path of
    ``imagenet1000_clsidx_to_labels.txt`` under ``self.root``, fetching
    the file with ``download`` when it does not exist, then parses each
    ``<int>:<text>`` line into ``self.human2integer_dict[text] = int``.

    Raises:
        AssertionError: if the file does not hold exactly 1000 lines
            (only when assertions are enabled).
        ValueError: if a line has no ':' or its prefix is not an integer.
    """
    URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
    self.human2integer = os.path.join(
        self.root, 'imagenet1000_clsidx_to_labels.txt'
    )
    if not os.path.exists(self.human2integer):
        download(URL, self.human2integer)
    with open(self.human2integer, 'r') as f:
        lines = f.read().splitlines()
    assert len(lines) == 1000
    self.human2integer_dict = {}
    for line in lines:
        # Split on the first ':' only, so label text that itself contains
        # a colon no longer breaks the two-way unpacking.
        value, key = line.split(':', 1)
        self.human2integer_dict[key] = int(value)
def _load(self): def _load(self):
with open(self.txt_filelist, "r") as f: with open(self.txt_filelist, 'r') as f:
self.relpaths = f.read().splitlines() self.relpaths = f.read().splitlines()
l1 = len(self.relpaths) l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths) self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) print(
'Removed {} files from filelist during filtering.'.format(
l1 - len(self.relpaths)
)
)
self.synsets = [p.split("/")[0] for p in self.relpaths] self.synsets = [p.split('/')[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets) unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) class_dict = dict(
(synset, i) for i, synset in enumerate(unique_synsets)
)
if not self.keep_orig_class_label: if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets] self.class_labels = [class_dict[s] for s in self.synsets]
else: else:
self.class_labels = [self.synset2idx[s] for s in self.synsets] self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, "r") as f: with open(self.human_dict, 'r') as f:
human_dict = f.read().splitlines() human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict) human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets] self.human_labels = [human_dict[s] for s in self.synsets]
labels = { labels = {
"relpath": np.array(self.relpaths), 'relpath': np.array(self.relpaths),
"synsets": np.array(self.synsets), 'synsets': np.array(self.synsets),
"class_label": np.array(self.class_labels), 'class_label': np.array(self.class_labels),
"human_label": np.array(self.human_labels), 'human_label': np.array(self.human_labels),
} }
if self.process_images: if self.process_images:
self.size = retrieve(self.config, "size", default=256) self.size = retrieve(self.config, 'size', default=256)
self.data = ImagePaths(self.abspaths, self.data = ImagePaths(
self.abspaths,
labels=labels, labels=labels,
size=self.size, size=self.size,
random_crop=self.random_crop, random_crop=self.random_crop,
@ -132,11 +159,11 @@ class ImageNetBase(Dataset):
class ImageNetTrain(ImageNetBase): class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train" NAME = 'ILSVRC2012_train'
URL = "http://www.image-net.org/challenges/LSVRC/2012/" URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
FILES = [ FILES = [
"ILSVRC2012_img_train.tar", 'ILSVRC2012_img_train.tar',
] ]
SIZES = [ SIZES = [
147897477120, 147897477120,
@ -151,57 +178,64 @@ class ImageNetTrain(ImageNetBase):
if self.data_root: if self.data_root:
self.root = os.path.join(self.data_root, self.NAME) self.root = os.path.join(self.data_root, self.NAME)
else: else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) cachedir = os.environ.get(
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) 'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, "data") self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, "filelist.txt") self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 1281167 self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", self.random_crop = retrieve(
default=True) self.config, 'ImageNetTrain/random_crop', default=True
)
if not tdu.is_prepared(self.root): if not tdu.is_prepared(self.root):
# prep # prep
print("Preparing dataset {} in {}".format(self.NAME, self.root)) print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir datadir = self.datadir
if not os.path.exists(datadir): if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0]) path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root) atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path assert atpath == path
print("Extracting {} to {}".format(path, datadir)) print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True) os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar: with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir) tar.extractall(path=datadir)
print("Extracting sub-tars.") print('Extracting sub-tars.')
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
for subpath in tqdm(subpaths): for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")] subdir = subpath[: -len('.tar')]
os.makedirs(subdir, exist_ok=True) os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar: with tarfile.open(subpath, 'r:') as tar:
tar.extractall(path=subdir) tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist) filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n" filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, "w") as f: with open(self.txt_filelist, 'w') as f:
f.write(filelist) f.write(filelist)
tdu.mark_prepared(self.root) tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase): class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation" NAME = 'ILSVRC2012_validation'
URL = "http://www.image-net.org/challenges/LSVRC/2012/" URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
FILES = [ FILES = [
"ILSVRC2012_img_val.tar", 'ILSVRC2012_img_val.tar',
"validation_synset.txt", 'validation_synset.txt',
] ]
SIZES = [ SIZES = [
6744924160, 6744924160,
@ -217,39 +251,49 @@ class ImageNetValidation(ImageNetBase):
if self.data_root: if self.data_root:
self.root = os.path.join(self.data_root, self.NAME) self.root = os.path.join(self.data_root, self.NAME)
else: else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) cachedir = os.environ.get(
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) 'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
self.datadir = os.path.join(self.root, "data") )
self.txt_filelist = os.path.join(self.root, "filelist.txt") self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 50000 self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", self.random_crop = retrieve(
default=False) self.config, 'ImageNetValidation/random_crop', default=False
)
if not tdu.is_prepared(self.root): if not tdu.is_prepared(self.root):
# prep # prep
print("Preparing dataset {} in {}".format(self.NAME, self.root)) print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir datadir = self.datadir
if not os.path.exists(datadir): if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0]) path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root) atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path assert atpath == path
print("Extracting {} to {}".format(path, datadir)) print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True) os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar: with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir) tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1]) vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: if (
not os.path.exists(vspath)
or not os.path.getsize(vspath) == self.SIZES[1]
):
download(self.VS_URL, vspath) download(self.VS_URL, vspath)
with open(vspath, "r") as f: with open(vspath, 'r') as f:
synset_dict = f.read().splitlines() synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict) synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders") print('Reorganizing into synset folders')
synsets = np.unique(list(synset_dict.values())) synsets = np.unique(list(synset_dict.values()))
for s in synsets: for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True) os.makedirs(os.path.join(datadir, s), exist_ok=True)
@ -258,21 +302,26 @@ class ImageNetValidation(ImageNetBase):
dst = os.path.join(datadir, v) dst = os.path.join(datadir, v)
shutil.move(src, dst) shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist) filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n" filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, "w") as f: with open(self.txt_filelist, 'w') as f:
f.write(filelist) f.write(filelist)
tdu.mark_prepared(self.root) tdu.mark_prepared(self.root)
class ImageNetSR(Dataset): class ImageNetSR(Dataset):
def __init__(self, size=None, def __init__(
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., self,
random_crop=True): size=None,
degradation=None,
downscale_f=4,
min_crop_f=0.5,
max_crop_f=1.0,
random_crop=True,
):
""" """
Imagenet Superresolution Dataloader Imagenet Superresolution Dataloader
Performs following ops in order: Performs following ops in order:
@ -296,67 +345,86 @@ class ImageNetSR(Dataset):
self.LR_size = int(size / downscale_f) self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.) assert max_crop_f <= 1.0
self.center_crop = not random_crop self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) self.image_rescaler = albumentations.SmallestMaxSize(
max_size=size, interpolation=cv2.INTER_AREA
)
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow self.pil_interpolation = (
False # gets reset later if incase interp_op is from pillow
)
if degradation == "bsrgan": if degradation == 'bsrgan':
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) self.degradation_process = partial(
degradation_fn_bsr, sf=downscale_f
)
elif degradation == "bsrgan_light": elif degradation == 'bsrgan_light':
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) self.degradation_process = partial(
degradation_fn_bsr_light, sf=downscale_f
)
else: else:
interpolation_fn = { interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST, 'cv_nearest': cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR, 'cv_bilinear': cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC, 'cv_bicubic': cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA, 'cv_area': cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4, 'cv_lanczos': cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST, 'pil_nearest': PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR, 'pil_bilinear': PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC, 'pil_bicubic': PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX, 'pil_box': PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING, 'pil_hamming': PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS, 'pil_lanczos': PIL.Image.LANCZOS,
}[degradation] }[degradation]
self.pil_interpolation = degradation.startswith("pil_") self.pil_interpolation = degradation.startswith('pil_')
if self.pil_interpolation: if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) self.degradation_process = partial(
TF.resize,
size=self.LR_size,
interpolation=interpolation_fn,
)
else: else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, self.degradation_process = albumentations.SmallestMaxSize(
interpolation=interpolation_fn) max_size=self.LR_size, interpolation=interpolation_fn
)
def __len__(self):
    """Return the number of entries in ``self.base``."""
    base_dataset = self.base
    return len(base_dataset)
def __getitem__(self, i): def __getitem__(self, i):
example = self.base[i] example = self.base[i]
image = Image.open(example["file_path_"]) image = Image.open(example['file_path_'])
if not image.mode == "RGB": if not image.mode == 'RGB':
image = image.convert("RGB") image = image.convert('RGB')
image = np.array(image).astype(np.uint8) image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2]) min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) crop_side_len = min_side_len * np.random.uniform(
self.min_crop_f, self.max_crop_f, size=None
)
crop_side_len = int(crop_side_len) crop_side_len = int(crop_side_len)
if self.center_crop: if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) self.cropper = albumentations.CenterCrop(
height=crop_side_len, width=crop_side_len
)
else: else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) self.cropper = albumentations.RandomCrop(
height=crop_side_len, width=crop_side_len
)
image = self.cropper(image=image)["image"] image = self.cropper(image=image)['image']
image = self.image_rescaler(image=image)["image"] image = self.image_rescaler(image=image)['image']
if self.pil_interpolation: if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image) image_pil = PIL.Image.fromarray(image)
@ -364,10 +432,10 @@ class ImageNetSR(Dataset):
LR_image = np.array(LR_image).astype(np.uint8) LR_image = np.array(LR_image).astype(np.uint8)
else: else:
LR_image = self.degradation_process(image=image)["image"] LR_image = self.degradation_process(image=image)['image']
example["image"] = (image/127.5 - 1.0).astype(np.float32) example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)
return example return example
@ -377,9 +445,11 @@ class ImageNetSRTrain(ImageNetSR):
super().__init__(**kwargs) super().__init__(**kwargs)
def get_base(self):
    """Return a ``Subset`` of ``ImageNetTrain`` limited to stored indices.

    The indices are unpickled from ``data/imagenet_train_hr_indices.p``;
    the wrapped dataset is created with ``process_images=False``.
    """
    # NOTE: pickle.load can execute arbitrary code from the file, so this
    # path must only ever point at a trusted, locally produced index file.
    with open('data/imagenet_train_hr_indices.p', 'rb') as index_file:
        hr_indices = pickle.load(index_file)
    full_dataset = ImageNetTrain(process_images=False)
    return Subset(full_dataset, hr_indices)
@ -388,7 +458,9 @@ class ImageNetSRValidation(ImageNetSR):
super().__init__(**kwargs) super().__init__(**kwargs)
def get_base(self):
    """Return a ``Subset`` of ``ImageNetValidation`` limited to stored indices.

    The indices are unpickled from ``data/imagenet_val_hr_indices.p``;
    the wrapped dataset is created with ``process_images=False``.
    """
    # NOTE: pickle.load can execute arbitrary code from the file, so this
    # path must only ever point at a trusted, locally produced index file.
    with open('data/imagenet_val_hr_indices.p', 'rb') as index_file:
        hr_indices = pickle.load(index_file)
    full_dataset = ImageNetValidation(process_images=False)
    return Subset(full_dataset, hr_indices)

View File

@ -7,29 +7,32 @@ from torchvision import transforms
class LSUNBase(Dataset): class LSUNBase(Dataset):
def __init__(
    self,
    txt_file,
    data_root,
    size=None,
    interpolation='bicubic',
    flip_p=0.5,
):
    """Index the images listed in ``txt_file``.

    Args:
        txt_file: path of a text file read line by line; each line is
            joined onto ``data_root`` to form an image path.
        data_root: directory prefix for the listed paths.
        size: stored unchanged as ``self.size``.
        interpolation: one of 'linear', 'bilinear', 'bicubic' or
            'lanczos'; selects the PIL resampling filter kept in
            ``self.interpolation``.
        flip_p: probability handed to ``transforms.RandomHorizontalFlip``.

    Raises:
        KeyError: if ``interpolation`` is not one of the names above.
    """
    self.data_paths = txt_file
    self.data_root = data_root
    with open(self.data_paths, 'r') as f:
        self.image_paths = f.read().splitlines()
    self._length = len(self.image_paths)
    self.labels = {
        'relative_file_path_': list(self.image_paths),
        'file_path_': [
            os.path.join(self.data_root, l) for l in self.image_paths
        ],
    }

    self.size = size
    # PIL.Image.LINEAR was only an alias of BILINEAR and was removed in
    # Pillow 10; referencing it made this table raise AttributeError for
    # every interpolation value. Map 'linear' to BILINEAR directly.
    self.interpolation = {
        'linear': PIL.Image.BILINEAR,
        'bilinear': PIL.Image.BILINEAR,
        'bicubic': PIL.Image.BICUBIC,
        'lanczos': PIL.Image.LANCZOS,
    }[interpolation]
    self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -38,55 +41,86 @@ class LSUNBase(Dataset):
def __getitem__(self, i):
    """Return example ``i``: its label entries plus the processed image.

    The image is converted to RGB if needed, center-cropped to a square
    along its shorter side, resized to ``self.size`` when that is set,
    passed through ``self.flip`` and stored under 'image' as float32
    rescaled from [0, 255] to [-1, 1].
    """
    example = {key: values[i] for key, values in self.labels.items()}
    image = Image.open(example['file_path_'])
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # default to score-sde preprocessing
    pixels = np.array(image).astype(np.uint8)
    h, w = pixels.shape[0], pixels.shape[1]
    side = min(h, w)
    rows = slice((h - side) // 2, (h + side) // 2)
    cols = slice((w - side) // 2, (w + side) // 2)
    pixels = pixels[rows, cols]

    image = Image.fromarray(pixels)
    if self.size is not None:
        target = (self.size, self.size)
        image = image.resize(target, resample=self.interpolation)

    image = self.flip(image)
    image = np.array(image).astype(np.uint8)
    example['image'] = (image / 127.5 - 1.0).astype(np.float32)
    return example
class LSUNChurchesTrain(LSUNBase):
    """``LSUNBase`` bound to the church-outdoor training file list."""

    def __init__(self, **kwargs):
        # Remaining keyword arguments are forwarded untouched.
        super().__init__(
            txt_file='data/lsun/church_outdoor_train.txt',
            data_root='data/lsun/churches',
            **kwargs,
        )
class LSUNChurchesValidation(LSUNBase):
    """``LSUNBase`` bound to the church-outdoor validation file list."""

    def __init__(self, flip_p=0.0, **kwargs):
        # flip_p defaults to 0.0 here instead of the base class's 0.5.
        super().__init__(
            txt_file='data/lsun/church_outdoor_val.txt',
            data_root='data/lsun/churches',
            flip_p=flip_p,
            **kwargs,
        )
class LSUNBedroomsTrain(LSUNBase):
    """``LSUNBase`` bound to the bedrooms training file list."""

    def __init__(self, **kwargs):
        # Remaining keyword arguments are forwarded untouched.
        super().__init__(
            txt_file='data/lsun/bedrooms_train.txt',
            data_root='data/lsun/bedrooms',
            **kwargs,
        )
class LSUNBedroomsValidation(LSUNBase):
    """``LSUNBase`` bound to the bedrooms validation file list."""

    def __init__(self, flip_p=0.0, **kwargs):
        # flip_p defaults to 0.0 here instead of the base class's 0.5.
        super().__init__(
            txt_file='data/lsun/bedrooms_val.txt',
            data_root='data/lsun/bedrooms',
            flip_p=flip_p,
            **kwargs,
        )
class LSUNCatsTrain(LSUNBase):
    """``LSUNBase`` bound to the cats training file list."""

    def __init__(self, **kwargs):
        # Remaining keyword arguments are forwarded untouched.
        super().__init__(
            txt_file='data/lsun/cat_train.txt',
            data_root='data/lsun/cats',
            **kwargs,
        )
class LSUNCatsValidation(LSUNBase):
    """``LSUNBase`` bound to the cats validation file list."""

    def __init__(self, flip_p=0.0, **kwargs):
        # flip_p defaults to 0.0 here instead of the base class's 0.5.
        super().__init__(
            txt_file='data/lsun/cat_val.txt',
            data_root='data/lsun/cats',
            flip_p=flip_p,
            **kwargs,
        )

View File

@ -72,18 +72,41 @@ imagenet_dual_templates_small = [
] ]
# One single-character token per image: the 22 Hebrew letters from alef to
# tav, final forms excluded, in alphabetical order.
per_img_token_list = list('אבגדהוזחטיכלמנסעפצקרשת')
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, def __init__(
self,
data_root, data_root,
size=None, size=None,
repeats=100, repeats=100,
interpolation="bicubic", interpolation='bicubic',
flip_p=0.5, flip_p=0.5,
set="train", set='train',
placeholder_token="*", placeholder_token='*',
per_image_tokens=False, per_image_tokens=False,
center_crop=False, center_crop=False,
mixing_prob=0.25, mixing_prob=0.25,
@ -92,7 +115,10 @@ class PersonalizedBase(Dataset):
self.data_root = data_root self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths) # self._length = len(self.image_paths)
self.num_images = len(self.image_paths) self.num_images = len(self.image_paths)
@ -107,16 +133,19 @@ class PersonalizedBase(Dataset):
self.coarse_class_text = coarse_class_text self.coarse_class_text = coarse_class_text
if per_image_tokens: if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train": if set == 'train':
self._length = self.num_images * repeats self._length = self.num_images * repeats
self.size = size self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR, self.interpolation = {
"bilinear": PIL.Image.BILINEAR, 'linear': PIL.Image.LINEAR,
"bicubic": PIL.Image.BICUBIC, 'bilinear': PIL.Image.BILINEAR,
"lanczos": PIL.Image.LANCZOS, 'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation] }[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -127,34 +156,47 @@ class PersonalizedBase(Dataset):
example = {} example = {}
image = Image.open(self.image_paths[i % self.num_images]) image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB": if not image.mode == 'RGB':
image = image.convert("RGB") image = image.convert('RGB')
placeholder_string = self.placeholder_token placeholder_string = self.placeholder_token
if self.coarse_class_text: if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}" placeholder_string = (
f'{self.coarse_class_text} {placeholder_string}'
)
if self.per_image_tokens and np.random.uniform() < self.mixing_prob: if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images]) text = random.choice(imagenet_dual_templates_small).format(
placeholder_string, per_img_token_list[i % self.num_images]
)
else: else:
text = random.choice(imagenet_templates_small).format(placeholder_string) text = random.choice(imagenet_templates_small).format(
placeholder_string
)
example["caption"] = text example['caption'] = text
# default to score-sde preprocessing # default to score-sde preprocessing
img = np.array(image).astype(np.uint8) img = np.array(image).astype(np.uint8)
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1] h, w, = (
img = img[(h - crop) // 2:(h + crop) // 2, img.shape[0],
(w - crop) // 2:(w + crop) // 2] img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img) image = Image.fromarray(img)
if self.size is not None: if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation) image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image) image = self.flip(image)
image = np.array(image).astype(np.uint8) image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32) example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example return example

View File

@ -50,25 +50,51 @@ imagenet_dual_templates_small = [
] ]
# Single-character tokens handed out one per training image: the 22 letters
# of the Hebrew alphabet, alef through tav, in alphabetical order.
per_img_token_list = list('אבגדהוזחטיכלמנסעפצקרשת')
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, def __init__(
self,
data_root, data_root,
size=None, size=None,
repeats=100, repeats=100,
interpolation="bicubic", interpolation='bicubic',
flip_p=0.5, flip_p=0.5,
set="train", set='train',
placeholder_token="*", placeholder_token='*',
per_image_tokens=False, per_image_tokens=False,
center_crop=False, center_crop=False,
): ):
self.data_root = data_root self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths) # self._length = len(self.image_paths)
self.num_images = len(self.image_paths) self.num_images = len(self.image_paths)
@ -80,16 +106,19 @@ class PersonalizedBase(Dataset):
self.center_crop = center_crop self.center_crop = center_crop
if per_image_tokens: if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train": if set == 'train':
self._length = self.num_images * repeats self._length = self.num_images * repeats
self.size = size self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR, self.interpolation = {
"bilinear": PIL.Image.BILINEAR, 'linear': PIL.Image.LINEAR,
"bicubic": PIL.Image.BICUBIC, 'bilinear': PIL.Image.BILINEAR,
"lanczos": PIL.Image.LANCZOS, 'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation] }[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -100,30 +129,41 @@ class PersonalizedBase(Dataset):
example = {} example = {}
image = Image.open(self.image_paths[i % self.num_images]) image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB": if not image.mode == 'RGB':
image = image.convert("RGB") image = image.convert('RGB')
if self.per_image_tokens and np.random.uniform() < 0.25: if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images]) text = random.choice(imagenet_dual_templates_small).format(
self.placeholder_token, per_img_token_list[i % self.num_images]
)
else: else:
text = random.choice(imagenet_templates_small).format(self.placeholder_token) text = random.choice(imagenet_templates_small).format(
self.placeholder_token
)
example["caption"] = text example['caption'] = text
# default to score-sde preprocessing # default to score-sde preprocessing
img = np.array(image).astype(np.uint8) img = np.array(image).astype(np.uint8)
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1] h, w, = (
img = img[(h - crop) // 2:(h + crop) // 2, img.shape[0],
(w - crop) // 2:(w + crop) // 2] img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img) image = Image.fromarray(img)
if self.size is not None: if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation) image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image) image = self.flip(image)
image = np.array(image).astype(np.uint8) image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32) example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example return example

View File

@ -1,4 +1,4 @@
''' """
Two helper classes for dealing with PNG images and their path names. Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds PngWriter -- Converts Images generated by T2I into PNGs, finds
appropriate names for them, and writes prompt metadata appropriate names for them, and writes prompt metadata
@ -7,7 +7,7 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds
prompt for file/directory names. prompt for file/directory names.
PromptFormatter -- Utility for converting a Namespace of prompt parameters PromptFormatter -- Utility for converting a Namespace of prompt parameters
back into a formatted prompt string with command-line switches. back into a formatted prompt string with command-line switches.
''' """
import os import os
import re import re
from math import sqrt, floor, ceil from math import sqrt, floor, ceil
@ -15,7 +15,6 @@ from PIL import Image,PngImagePlugin
# -------------------image generation utils----- # -------------------image generation utils-----
class PngWriter: class PngWriter:
def __init__(self, outdir, prompt=None, batch_size=1): def __init__(self, outdir, prompt=None, batch_size=1):
self.outdir = outdir self.outdir = outdir
self.batch_size = batch_size self.batch_size = batch_size
@ -25,7 +24,9 @@ class PngWriter:
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
def write_image(self, image, seed): def write_image(self, image, seed):
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way self.filepath = self.unique_filename(
seed, self.filepath
) # will increment name in some sensible way
try: try:
prompt = f'{self.prompt} -S{seed}' prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image, prompt, self.filepath) self.save_image_and_prompt_to_png(image, prompt, self.filepath)
@ -40,7 +41,10 @@ class PngWriter:
# sort reverse alphabetically until we find max+1 # sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir), reverse=True) dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png # find the first filename that matches our pattern or return 000000.0.png
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') filename = next(
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
'0000000.0.png',
)
basecount = int(filename.split('.', 1)[0]) basecount = int(filename.split('.', 1)[0])
basecount += 1 basecount += 1
if self.batch_size > 1: if self.batch_size > 1:
@ -61,15 +65,19 @@ class PngWriter:
while not finished: while not finished:
series += 1 series += 1
filename = f'{basecount:06}.{seed}.png' filename = f'{basecount:06}.{seed}.png'
if self.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)): if self.batch_size > 1 or os.path.exists(
os.path.join(self.outdir, filename)
):
filename = f'{basecount:06}.{seed}.{series:02}.png' filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(os.path.join(self.outdir,filename)) finished = not os.path.exists(
os.path.join(self.outdir, filename)
)
return os.path.join(self.outdir, filename) return os.path.join(self.outdir, filename)
def save_image_and_prompt_to_png(self, image, prompt, path):
    """Save `image` to `path` as a PNG that carries `prompt` in its metadata.

    The prompt is attached as a PNG text entry under the key 'Dream'.
    """
    metadata = PngImagePlugin.PngInfo()
    metadata.add_text('Dream', prompt)
    image.save(path, 'PNG', pnginfo=metadata)
def make_grid(self, image_list, rows=None, cols=None): def make_grid(self, image_list, rows=None, cols=None):
image_cnt = len(image_list) image_cnt = len(image_list)
@ -87,13 +95,14 @@ class PngWriter:
return grid_img return grid_img
class PromptFormatter():
class PromptFormatter:
def __init__(self, t2i, opt):
    """Keep the `t2i` object and the `opt` namespace that normalize_prompt() reads."""
    self.t2i, self.opt = t2i, opt
def normalize_prompt(self): def normalize_prompt(self):
'''Normalize the prompt and switches''' """Normalize the prompt and switches"""
t2i = self.t2i t2i = self.t2i
opt = self.opt opt = self.opt
@ -114,4 +123,3 @@ class PromptFormatter():
if t2i.full_precision: if t2i.full_precision:
switches.append('-F') switches.append('-F')
return ' '.join(switches) return ' '.join(switches)

View File

@ -1,17 +1,20 @@
''' """
Readline helper functions for dream.py (linux and mac only). Readline helper functions for dream.py (linux and mac only).
''' """
import os import os
import re import re
import atexit import atexit
# ---------------readline utilities--------------------- # ---------------readline utilities---------------------
# readline is an optional dependency: record whether it could be imported so
# the completion/history setup guarded by `readline_available` can be skipped.
try:
    import readline
except ImportError:
    # Only a failed import means "not available"; anything else (for example
    # KeyboardInterrupt) should propagate instead of being silently swallowed.
    readline_available = False
else:
    readline_available = True
class Completer():
class Completer:
def __init__(self, options):
    """Store the completion candidates from `options` as a sorted list."""
    self.options = sorted(options)
@ -29,9 +32,9 @@ class Completer():
if state == 0: if state == 0:
# This is the first time for this text, so build a match list. # This is the first time for this text, so build a match list.
if text: if text:
self.matches = [s self.matches = [
for s in self.options s for s in self.options if s and s.startswith(text)
if s and s.startswith(text)] ]
else: else:
self.matches = self.options[:] self.matches = self.options[:]
@ -66,7 +69,9 @@ class Completer():
full_path = os.path.join(dir, n) full_path = os.path.join(dir, n)
if full_path.startswith(path): if full_path.startswith(path):
if os.path.isdir(full_path): if os.path.isdir(full_path):
matches.append(os.path.join(os.path.dirname(text),n)+'/') matches.append(
os.path.join(os.path.dirname(text), n) + '/'
)
elif n.endswith(extensions): elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text), n)) matches.append(os.path.join(os.path.dirname(text), n))
@ -76,19 +81,47 @@ class Completer():
response = None response = None
return response return response
if readline_available:
    # Offer the built-in commands and every prompt switch on <tab>.
    readline.set_completer(
        Completer(
            [
                'cd',
                'pwd',
                '--steps',
                '-s',
                '--seed',
                '-S',
                '--iterations',
                '-n',
                '--batch_size',
                '-b',
                '--width',
                '-W',
                '--height',
                '-H',
                '--cfg_scale',
                '-C',
                '--grid',
                '-g',
                '--individual',
                '-i',
                '--init_img',
                '-I',
                '--strength',
                '-f',
                '-v',
                '--variants',
            ]
        ).complete
    )
    # Split completion words on spaces only.
    readline.set_completer_delims(' ')
    readline.parse_and_bind('tab: complete')

    histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
    # Set the history cap unconditionally. It used to live inside the `try`
    # below, so it was skipped whenever the history file did not exist yet.
    readline.set_history_length(1000)
    try:
        readline.read_history_file(histfile)
    except FileNotFoundError:
        # No history saved yet; the atexit hook below writes the file on exit.
        pass
    atexit.register(readline.write_history_file, histfile)

View File

@ -5,27 +5,44 @@ class LambdaWarmUpCosineScheduler:
""" """
note: use with a base_lr of 1.0 note: use with a base_lr of 1.0
""" """
def __init__(
    self,
    warm_up_steps,
    lr_min,
    lr_max,
    lr_start,
    max_decay_steps,
    verbosity_interval=0,
):
    """Store the parameters that schedule() reads.

    warm_up_steps:      steps of linear ramp from `lr_start` to `lr_max`
    lr_min:             multiplier reached at the end of the cosine decay
    lr_max:             multiplier reached at the end of the warm-up
    lr_start:           multiplier at step 0
    max_decay_steps:    step at which the cosine decay reaches `lr_min`
    verbosity_interval: print progress every this many steps (0 = never)
    """
    # Linear warm-up phase.
    self.lr_start = lr_start
    self.lr_warm_up_steps = warm_up_steps
    # Cosine decay phase.
    self.lr_max = lr_max
    self.lr_min = lr_min
    self.lr_max_decay_steps = max_decay_steps
    # Reporting state: the most recently returned multiplier.
    self.verbosity_interval = verbosity_interval
    self.last_lr = 0.0
def schedule(self, n, **kwargs):
    """Return the learning-rate multiplier for step `n`.

    For `n < lr_warm_up_steps` the multiplier rises linearly from `lr_start`
    to `lr_max`. Afterwards it follows half a cosine period from `lr_max`
    down to `lr_min`, reaching `lr_min` at `lr_max_decay_steps` and staying
    there. The returned value is also stored in `self.last_lr`.

    Extra keyword arguments are accepted and ignored.
    """
    if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
        print(f'current step: {n}, recent lr-multiplier: {self.last_lr}')

    if n < self.lr_warm_up_steps:
        lr = (
            self.lr_max - self.lr_start
        ) / self.lr_warm_up_steps * n + self.lr_start
    else:
        decay_span = self.lr_max_decay_steps - self.lr_warm_up_steps
        if decay_span > 0:
            t = min((n - self.lr_warm_up_steps) / decay_span, 1.0)
        else:
            # No room for a decay phase (max_decay_steps <= warm_up_steps):
            # treat the decay as already finished instead of dividing by zero.
            t = 1.0
        lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
            1 + np.cos(t * np.pi)
        )

    self.last_lr = lr
    return lr
@ -38,15 +55,30 @@ class LambdaWarmUpCosineScheduler2:
supports repeated iterations, configurable via lists supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0. note: use with a base_lr of 1.0.
""" """
def __init__(
    self,
    warm_up_steps,
    f_min,
    f_max,
    f_start,
    cycle_lengths,
    verbosity_interval=0,
):
    """Store per-cycle schedule parameters.

    `warm_up_steps`, `f_min`, `f_max`, `f_start` and `cycle_lengths` are
    sequences with one entry per cycle and must all have the same length.
    `verbosity_interval` is stored unchanged.
    """
    assert all(
        len(per_cycle) == len(warm_up_steps)
        for per_cycle in (f_min, f_max, f_start, cycle_lengths)
    )
    self.lr_warm_up_steps = warm_up_steps
    self.f_start, self.f_min, self.f_max = f_start, f_min, f_max
    self.cycle_lengths = cycle_lengths
    # Cumulative step count at which each cycle starts, with a leading 0.
    self.cum_cycles = np.cumsum([0, *self.cycle_lengths])
    self.last_f = 0.0
    self.verbosity_interval = verbosity_interval
def find_in_interval(self, n): def find_in_interval(self, n):
@ -60,17 +92,25 @@ class LambdaWarmUpCosineScheduler2:
cycle = self.find_in_interval(n) cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle] n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0: if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " if n % self.verbosity_interval == 0:
f"current cycle {cycle}") print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}'
)
if n < self.lr_warm_up_steps[cycle]: if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] f = (
self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f self.last_f = f
return f return f
else: else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0) t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( f = self.f_min[cycle] + 0.5 * (
1 + np.cos(t * np.pi)) self.f_max[cycle] - self.f_min[cycle]
) * (1 + np.cos(t * np.pi))
self.last_f = f self.last_f = f
return f return f
@ -79,20 +119,25 @@ class LambdaWarmUpCosineScheduler2:
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Multi-cycle scheduler whose post-warm-up phase is linear, not cosine."""

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step `n` and store it in `last_f`.

        `n` is first made relative to the start of its cycle. During warm-up
        the multiplier rises linearly from `f_start` to `f_max`; afterwards it
        is interpolated linearly so that it equals `f_min` when the relative
        step reaches the cycle length.
        """
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]

        report = self.verbosity_interval > 0
        if report and n % self.verbosity_interval == 0:
            print(
                f'current step: {n}, recent lr-multiplier: {self.last_f}, '
                f'current cycle {cycle}'
            )

        if n < self.lr_warm_up_steps[cycle]:
            rise = self.f_max[cycle] - self.f_start[cycle]
            f = rise / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
        else:
            drop = self.f_max[cycle] - self.f_min[cycle]
            remaining = self.cycle_lengths[cycle] - n
            f = self.f_min[cycle] + drop * remaining / (
                self.cycle_lengths[cycle]
            )

        self.last_f = f
        return f

View File

@ -6,20 +6,23 @@ from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution from ldm.modules.distributions.distributions import (
DiagonalGaussianDistribution,
)
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule): class VQModel(pl.LightningModule):
def __init__(self, def __init__(
self,
ddconfig, ddconfig,
lossconfig, lossconfig,
n_embed, n_embed,
embed_dim, embed_dim,
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=[],
image_key="image", image_key='image',
colorize_nlabels=None, colorize_nlabels=None,
monitor=None, monitor=None,
batch_resize_range=None, batch_resize_range=None,
@ -27,7 +30,7 @@ class VQModel(pl.LightningModule):
lr_g_factor=1.0, lr_g_factor=1.0,
remap=None, remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False use_ema=False,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
@ -36,24 +39,34 @@ class VQModel(pl.LightningModule):
self.encoder = Encoder(**ddconfig) self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig) self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig) self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap, remap=remap,
sane_index_shape=sane_index_shape) sane_index_shape=sane_index_shape,
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) )
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
if colorize_nlabels is not None: if colorize_nlabels is not None:
assert type(colorize_nlabels) == int assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
self.batch_resize_range = batch_resize_range self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None: if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
)
self.use_ema = use_ema self.use_ema = use_ema
if self.use_ema: if self.use_ema:
self.model_ema = LitEma(self) self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@ -66,28 +79,30 @@ class VQModel(pl.LightningModule):
self.model_ema.store(self.parameters()) self.model_ema.store(self.parameters())
self.model_ema.copy_to(self) self.model_ema.copy_to(self)
if context is not None: if context is not None:
print(f"{context}: Switched to EMA weights") print(f'{context}: Switched to EMA weights')
try: try:
yield None yield None
finally: finally:
if self.use_ema: if self.use_ema:
self.model_ema.restore(self.parameters()) self.model_ema.restore(self.parameters())
if context is not None: if context is not None:
print(f"{context}: Restored training weights") print(f'{context}: Restored training weights')
def init_from_ckpt(self, path, ignore_keys=()):
    """Load the 'state_dict' stored in the checkpoint at `path` into this model.

    path:        checkpoint file readable by torch.load
    ignore_keys: iterable of key prefixes; every state-dict entry whose name
                 starts with one of them is dropped before loading

    Loading is non-strict; the numbers of missing and unexpected keys are
    printed.
    """
    sd = torch.load(path, map_location='cpu')['state_dict']
    prefixes = tuple(ignore_keys)
    for k in list(sd.keys()):
        # startswith() takes the whole tuple, so a key is deleted at most
        # once. Deleting inside a loop over the prefixes raised KeyError
        # whenever two prefixes matched the same key.
        if k.startswith(prefixes):
            print('Deleting key {} from state_dict.'.format(k))
            del sd[k]
    missing, unexpected = self.load_state_dict(sd, strict=False)
    print(
        f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
    )
    if len(missing) > 0:
        print(f'Missing Keys: {missing}')
        print(f'Unexpected Keys: {unexpected}')
def on_train_batch_end(self, *args, **kwargs): def on_train_batch_end(self, *args, **kwargs):
if self.use_ema: if self.use_ema:
@ -125,7 +140,11 @@ class VQModel(pl.LightningModule):
x = batch[k] x = batch[k]
if len(x.shape) == 3: if len(x.shape) == 3:
x = x[..., None] x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
if self.batch_resize_range is not None: if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0] lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1] upper_size = self.batch_resize_range[1]
@ -133,9 +152,11 @@ class VQModel(pl.LightningModule):
# do the first few batches with max size to avoid later oom # do the first few batches with max size to avoid later oom
new_resize = upper_size new_resize = upper_size
else: else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]: if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic") x = F.interpolate(x, size=new_resize, mode='bicubic')
x = x.detach() x = x.detach()
return x return x
@ -147,49 +168,99 @@ class VQModel(pl.LightningModule):
if optimizer_idx == 0: if optimizer_idx == 0:
# autoencode # autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, aeloss, log_dict_ae = self.loss(
last_layer=self.get_last_layer(), split="train", qloss,
predicted_indices=ind) x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
predicted_indices=ind,
)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return aeloss return aeloss
if optimizer_idx == 1: if optimizer_idx == 1:
# discriminator # discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, discloss, log_dict_disc = self.loss(
last_layer=self.get_last_layer(), split="train") qloss,
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return discloss return discloss
def validation_step(self, batch, batch_idx):
    """Validate `batch` with the current weights, then again inside ema_scope().

    Only the result of the first pass is returned; the second pass runs for
    what it logs under the '_ema' suffix.
    """
    result = self._validation_step(batch, batch_idx)
    with self.ema_scope():
        self._validation_step(batch, batch_idx, suffix='_ema')
    return result
def _validation_step(self, batch, batch_idx, suffix=''):
    """Compute and log the validation losses for one batch.

    The loss module is evaluated twice, with optimizer index 0 and then 1,
    using the split name 'val' + `suffix`. The reconstruction loss and the
    index-0 loss are logged per epoch, then both loss dicts are logged.
    """
    x = self.get_input(batch, self.image_key)
    xrec, qloss, ind = self(x, return_pred_indices=True)

    def run_loss(optimizer_idx):
        # Same call for both passes; only the optimizer index differs.
        return self.loss(
            qloss,
            x,
            xrec,
            optimizer_idx,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val' + suffix,
            predicted_indices=ind,
        )

    aeloss, log_dict_ae = run_loss(0)
    _, log_dict_disc = run_loss(1)

    rec_key = f'val{suffix}/rec_loss'
    per_epoch = dict(
        prog_bar=True,
        logger=True,
        on_step=False,
        on_epoch=True,
        sync_dist=True,
    )
    self.log(rec_key, log_dict_ae[rec_key], **per_epoch)
    self.log(f'val{suffix}/aeloss', aeloss, **per_epoch)

    if version.parse(pl.__version__) >= version.parse('1.4.0'):
        # rec_loss was logged explicitly above; presumably newer Lightning
        # rejects logging the same key twice -- TODO confirm.
        del log_dict_ae[rec_key]
    self.log_dict(log_dict_ae)
    self.log_dict(log_dict_disc)
    # NOTE(review): this returns the bound `self.log_dict` method, not a dict
    # of metrics. Kept as-is; confirm what callers are meant to receive.
    return self.log_dict
@ -197,31 +268,39 @@ class VQModel(pl.LightningModule):
def configure_optimizers(self): def configure_optimizers(self):
lr_d = self.learning_rate lr_d = self.learning_rate
lr_g = self.lr_g_factor * self.learning_rate lr_g = self.lr_g_factor * self.learning_rate
print("lr_d", lr_d) print('lr_d', lr_d)
print("lr_g", lr_g) print('lr_g', lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ opt_ae = torch.optim.Adam(
list(self.decoder.parameters())+ list(self.encoder.parameters())
list(self.quantize.parameters())+ + list(self.decoder.parameters())
list(self.quant_conv.parameters())+ + list(self.quantize.parameters())
list(self.post_quant_conv.parameters()), + list(self.quant_conv.parameters())
lr=lr_g, betas=(0.5, 0.9)) + list(self.post_quant_conv.parameters()),
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr_g,
lr=lr_d, betas=(0.5, 0.9)) betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None: if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config) scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...") print('Setting up LambdaLR scheduler...')
scheduler = [ scheduler = [
{ {
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 'scheduler': LambdaLR(
opt_ae, lr_lambda=scheduler.schedule
),
'interval': 'step', 'interval': 'step',
'frequency': 1 'frequency': 1,
}, },
{ {
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 'scheduler': LambdaLR(
opt_disc, lr_lambda=scheduler.schedule
),
'interval': 'step', 'interval': 'step',
'frequency': 1 'frequency': 1,
}, },
] ]
return [opt_ae, opt_disc], scheduler return [opt_ae, opt_disc], scheduler
@ -235,7 +314,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key) x = self.get_input(batch, self.image_key)
x = x.to(self.device) x = x.to(self.device)
if only_inputs: if only_inputs:
log["inputs"] = x log['inputs'] = x
return log return log
xrec, _ = self(x) xrec, _ = self(x)
if x.shape[1] > 3: if x.shape[1] > 3:
@ -243,21 +322,24 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3 assert xrec.shape[1] > 3
x = self.to_rgb(x) x = self.to_rgb(x)
xrec = self.to_rgb(xrec) xrec = self.to_rgb(xrec)
log["inputs"] = x log['inputs'] = x
log["reconstructions"] = xrec log['reconstructions'] = xrec
if plot_ema: if plot_ema:
with self.ema_scope(): with self.ema_scope():
xrec_ema, _ = self(x) xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) if x.shape[1] > 3:
log["reconstructions_ema"] = xrec_ema xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
return log return log
def to_rgb(self, x):
    """Project `x` to 3 channels and rescale the result to [-1, 1].

    Only valid when `image_key` is 'segmentation'. The projection is a 1x1
    convolution with the `colorize` buffer, which is created from random
    normal weights on first use if the model does not have one yet.
    A constant projection maps to -1 everywhere.
    """
    assert self.image_key == 'segmentation'
    if not hasattr(self, 'colorize'):
        self.register_buffer(
            'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
        )
    x = F.conv2d(x, weight=self.colorize)
    lo, hi = x.min(), x.max()
    # Guard the min-max rescale: when every value is equal the range is zero
    # and the unguarded division produced NaN for the whole tensor.
    span = (hi - lo).clamp_min(torch.finfo(x.dtype).tiny)
    x = 2.0 * (x - lo) / span - 1.0
    return x
@ -283,13 +365,14 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule): class AutoencoderKL(pl.LightningModule):
def __init__(self, def __init__(
self,
ddconfig, ddconfig,
lossconfig, lossconfig,
embed_dim, embed_dim,
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=[],
image_key="image", image_key='image',
colorize_nlabels=None, colorize_nlabels=None,
monitor=None, monitor=None,
): ):
@ -298,28 +381,34 @@ class AutoencoderKL(pl.LightningModule):
self.encoder = Encoder(**ddconfig) self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig) self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig) self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"] assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) self.quant_conv = torch.nn.Conv2d(
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 2 * ddconfig['z_channels'], 2 * embed_dim, 1
)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
self.embed_dim = embed_dim self.embed_dim = embed_dim
if colorize_nlabels is not None: if colorize_nlabels is not None:
assert type(colorize_nlabels) == int assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()): def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"] sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print('Deleting key {} from state_dict.'.format(k))
del sd[k] del sd[k]
self.load_state_dict(sd, strict=False) self.load_state_dict(sd, strict=False)
print(f"Restored from {path}") print(f'Restored from {path}')
def encode(self, x): def encode(self, x):
h = self.encoder(x) h = self.encoder(x)
@ -345,7 +434,11 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k] x = batch[k]
if len(x.shape) == 3: if len(x.shape) == 3:
x = x[..., None] x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
return x return x
def training_step(self, batch, batch_idx, optimizer_idx): def training_step(self, batch, batch_idx, optimizer_idx):
@ -354,44 +447,102 @@ class AutoencoderKL(pl.LightningModule):
if optimizer_idx == 0: if optimizer_idx == 0:
# train encoder+decoder+logvar # train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, aeloss, log_dict_ae = self.loss(
last_layer=self.get_last_layer(), split="train") inputs,
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) reconstructions,
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return aeloss return aeloss
if optimizer_idx == 1: if optimizer_idx == 1:
# train the discriminator # train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, discloss, log_dict_disc = self.loss(
last_layer=self.get_last_layer(), split="train") inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log(
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return discloss return discloss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key) inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs) reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, aeloss, log_dict_ae = self.loss(
last_layer=self.get_last_layer(), split="val") inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, discloss, log_dict_disc = self.loss(
last_layer=self.get_last_layer(), split="val") inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log_dict(log_dict_ae) self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc) self.log_dict(log_dict_disc)
return self.log_dict return self.log_dict
def configure_optimizers(self): def configure_optimizers(self):
lr = self.learning_rate lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ opt_ae = torch.optim.Adam(
list(self.decoder.parameters())+ list(self.encoder.parameters())
list(self.quant_conv.parameters())+ + list(self.decoder.parameters())
list(self.post_quant_conv.parameters()), + list(self.quant_conv.parameters())
lr=lr, betas=(0.5, 0.9)) + list(self.post_quant_conv.parameters()),
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr,
lr=lr, betas=(0.5, 0.9)) betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], [] return [opt_ae, opt_disc], []
def get_last_layer(self): def get_last_layer(self):
@ -409,17 +560,19 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3 assert xrec.shape[1] > 3
x = self.to_rgb(x) x = self.to_rgb(x)
xrec = self.to_rgb(xrec) xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample())) log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec log['reconstructions'] = xrec
log["inputs"] = x log['inputs'] = x
return log return log
def to_rgb(self, x): def to_rgb(self, x):
assert self.image_key == "segmentation" assert self.image_key == 'segmentation'
if not hasattr(self, "colorize"): if not hasattr(self, 'colorize'):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize) x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1. x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x return x

View File

@ -10,13 +10,13 @@ from einops import rearrange
from glob import glob from glob import glob
from natsort import natsorted from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel from ldm.modules.diffusionmodules.openaimodel import (
EncoderUNetModel,
UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = { __models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}
def disabled_train(self, mode=True): def disabled_train(self, mode=True):
@ -26,8 +26,8 @@ def disabled_train(self, mode=True):
class NoisyLatentImageClassifier(pl.LightningModule): class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(
def __init__(self, self,
diffusion_path, diffusion_path,
num_classes, num_classes,
ckpt_path=None, ckpt_path=None,
@ -35,28 +35,40 @@ class NoisyLatentImageClassifier(pl.LightningModule):
label_key=None, label_key=None,
diffusion_ckpt_path=None, diffusion_ckpt_path=None,
scheduler_config=None, scheduler_config=None,
weight_decay=1.e-2, weight_decay=1.0e-2,
log_steps=10, log_steps=10,
monitor='val/loss', monitor='val/loss',
*args, *args,
**kwargs): **kwargs,
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.num_classes = num_classes self.num_classes = num_classes
# get latest config of diffusion model # get latest config of diffusion model
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] diffusion_config = natsorted(
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion() self.load_diffusion()
self.monitor = monitor self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 self.numd = (
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
)
self.log_time_interval = (
self.diffusion_model.num_timesteps // log_steps
)
self.log_steps = log_steps self.log_steps = log_steps
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ self.label_key = (
label_key
if not hasattr(self.diffusion_model, 'cond_stage_key')
else self.diffusion_model.cond_stage_key else self.diffusion_model.cond_stage_key
)
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' assert (
self.label_key is not None
), 'label_key neither in diffusion model nor in model.params'
if self.label_key not in __models__: if self.label_key not in __models__:
raise NotImplementedError() raise NotImplementedError()
@ -68,22 +80,27 @@ class NoisyLatentImageClassifier(pl.LightningModule):
self.weight_decay = weight_decay self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu") sd = torch.load(path, map_location='cpu')
if "state_dict" in list(sd.keys()): if 'state_dict' in list(sd.keys()):
sd = sd["state_dict"] sd = sd['state_dict']
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print('Deleting key {} from state_dict.'.format(k))
del sd[k] del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( missing, unexpected = (
sd, strict=False) self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0: if len(missing) > 0:
print(f"Missing Keys: {missing}") print(f'Missing Keys: {missing}')
if len(unexpected) > 0: if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}") print(f'Unexpected Keys: {unexpected}')
def load_diffusion(self): def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config) model = instantiate_from_config(self.diffusion_config)
@ -93,17 +110,25 @@ class NoisyLatentImageClassifier(pl.LightningModule):
param.requires_grad = False param.requires_grad = False
def load_classifier(self, ckpt_path, pool): def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params) model_config = deepcopy(
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels self.diffusion_config.params.unet_config.params
)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes model_config.out_channels = self.num_classes
if self.label_key == 'class_label': if self.label_key == 'class_label':
model_config.pool = pool model_config.pool = pool
self.model = __models__[self.label_key](**model_config) self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None: if ckpt_path is not None:
print('#####################################################################') print(
'#####################################################################'
)
print(f'load from ckpt "{ckpt_path}"') print(f'load from ckpt "{ckpt_path}"')
print('#####################################################################') print(
'#####################################################################'
)
self.init_from_ckpt(ckpt_path) self.init_from_ckpt(ckpt_path)
@torch.no_grad() @torch.no_grad()
@ -111,11 +136,19 @@ class NoisyLatentImageClassifier(pl.LightningModule):
noise = default(noise, lambda: torch.randn_like(x)) noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise: if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(
x.shape[0], t + 1
)
)
# todo: make sure t+1 is correct here # todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, return self.diffusion_model.q_sample(
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs): def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t) return self.model(x_noisy, t)
@ -141,17 +174,21 @@ class NoisyLatentImageClassifier(pl.LightningModule):
targets = rearrange(targets, 'b h w c -> b c h w') targets = rearrange(targets, 'b h w c -> b c h w')
for down in range(self.numd): for down in range(self.numd):
h, w = targets.shape[-2:] h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') targets = F.interpolate(
targets, size=(h // 2, w // 2), mode='nearest'
)
# targets = rearrange(targets,'b c h w -> b h w c') # targets = rearrange(targets,'b c h w -> b h w c')
return targets return targets
def compute_top_k(self, logits, labels, k, reduction="mean"): def compute_top_k(self, logits, labels, k, reduction='mean'):
_, top_ks = torch.topk(logits, k, dim=1) _, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean": if reduction == 'mean':
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() return (
elif reduction == "none": (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
)
elif reduction == 'none':
return (top_ks == labels[:, None]).float().sum(dim=-1) return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self): def on_train_epoch_start(self):
@ -162,29 +199,59 @@ class NoisyLatentImageClassifier(pl.LightningModule):
def write_logs(self, loss, logits, targets): def write_logs(self, loss, logits, targets):
log_prefix = 'train' if self.training else 'val' log_prefix = 'train' if self.training else 'val'
log = {} log = {}
log[f"{log_prefix}/loss"] = loss.mean() log[f'{log_prefix}/loss'] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k( log[f'{log_prefix}/acc@1'] = self.compute_top_k(
logits, targets, k=1, reduction="mean" logits, targets, k=1, reduction='mean'
) )
log[f"{log_prefix}/acc@5"] = self.compute_top_k( log[f'{log_prefix}/acc@5'] = self.compute_top_k(
logits, targets, k=5, reduction="mean" logits, targets, k=5, reduction='mean'
) )
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) self.log_dict(
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) log,
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) prog_bar=False,
logger=True,
on_step=self.training,
on_epoch=True,
)
self.log(
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
)
self.log(
'global_step',
self.global_step,
logger=False,
on_epoch=False,
prog_bar=True,
)
lr = self.optimizers().param_groups[0]['lr'] lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) self.log(
'lr_abs',
lr,
on_step=True,
logger=True,
on_epoch=False,
prog_bar=True,
)
def shared_step(self, batch, t=None): def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch) targets = self.get_conditioning(batch)
if targets.dim() == 4: if targets.dim() == 4:
targets = targets.argmax(dim=1) targets = targets.argmax(dim=1)
if t is None: if t is None:
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() t = torch.randint(
0,
self.diffusion_model.num_timesteps,
(x.shape[0],),
device=self.device,
).long()
else: else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() t = torch.full(
size=(x.shape[0],), fill_value=t, device=self.device
).long()
x_noisy = self.get_x_noisy(x, t) x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t) logits = self(x_noisy, t)
@ -200,8 +267,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
return loss return loss
def reset_noise_accs(self): def reset_noise_accs(self):
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in self.noisy_acc = {
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} t: {'acc@1': [], 'acc@5': []}
for t in range(
0,
self.diffusion_model.num_timesteps,
self.diffusion_model.log_every_t,
)
}
def on_validation_start(self): def on_validation_start(self):
self.reset_noise_accs() self.reset_noise_accs()
@ -212,24 +285,35 @@ class NoisyLatentImageClassifier(pl.LightningModule):
for t in self.noisy_acc: for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t) _, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) self.noisy_acc[t]['acc@1'].append(
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) self.compute_top_k(logits, targets, k=1, reduction='mean')
)
self.noisy_acc[t]['acc@5'].append(
self.compute_top_k(logits, targets, k=5, reduction='mean')
)
return loss return loss
def configure_optimizers(self): def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler: if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config) scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...") print('Setting up LambdaLR scheduler...')
scheduler = [ scheduler = [
{ {
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 'scheduler': LambdaLR(
optimizer, lr_lambda=scheduler.schedule
),
'interval': 'step', 'interval': 'step',
'frequency': 1 'frequency': 1,
}] }
]
return [optimizer], scheduler return [optimizer], scheduler
return optimizer return optimizer
@ -243,7 +327,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
y = self.get_conditioning(batch) y = self.get_conditioning(batch)
if self.label_key == 'class_label': if self.label_key == 'class_label':
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
log['labels'] = y log['labels'] = y
if ismap(y): if ismap(y):
@ -256,10 +340,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
log[f'inputs@t{current_time}'] = x_noisy log[f'inputs@t{current_time}'] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) pred = F.one_hot(
logits.argmax(dim=1), num_classes=self.num_classes
)
pred = rearrange(pred, 'b h w c -> b c h w') pred = rearrange(pred, 'b h w c -> b c h w')
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
pred
)
for key in log: for key in log:
log[key] = log[key][:N] log[key] = log[key][:N]

View File

@ -5,12 +5,16 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ from ldm.modules.diffusionmodules.util import (
extract_into_tensor make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object): class DDIMSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs): def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
@ -23,39 +27,87 @@ class DDIMSampler(object):
attr = attr.to(torch.device(self.device)) attr = attr.to(torch.device(self.device))
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, self,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' assert (
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) )
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters # ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), (
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose) eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( (1 - self.alphas_cumprod_prev)
1 - self.alphas_cumprod / self.alphas_cumprod_prev)) / (1 - self.alphas_cumprod)
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(
self,
S, S,
batch_size, batch_size,
shape, shape,
@ -64,29 +116,33 @@ class DDIMSampler(object):
normals_sequence=None, normals_sequence=None,
img_callback=None, img_callback=None,
quantize_x0=False, quantize_x0=False,
eta=0., eta=0.0,
mask=None, mask=None,
x0=None, x0=None,
temperature=1., temperature=1.0,
noise_dropout=0., noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
verbose=True, verbose=True,
x_T=None, x_T=None,
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1., unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs **kwargs,
): ):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0] cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else: else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling # sampling
@ -94,11 +150,14 @@ class DDIMSampler(object):
size = (batch_size, C, H, W) size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}') print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size, samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback, callback=callback,
img_callback=img_callback, img_callback=img_callback,
quantize_denoised=quantize_x0, quantize_denoised=quantize_x0,
mask=mask, x0=x0, mask=mask,
x0=x0,
ddim_use_original_steps=False, ddim_use_original_steps=False,
noise_dropout=noise_dropout, noise_dropout=noise_dropout,
temperature=temperature, temperature=temperature,
@ -112,12 +171,26 @@ class DDIMSampler(object):
return samples, intermediates return samples, intermediates
@torch.no_grad() @torch.no_grad()
def ddim_sampling(self, cond, shape, def ddim_sampling(
x_T=None, ddim_use_original_steps=False, self,
callback=None, timesteps=None, quantize_denoised=False, cond,
mask=None, x0=None, img_callback=None, log_every_t=100, shape,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,): ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
@ -126,17 +199,38 @@ class DDIMSampler(object):
img = x_T img = x_T
if timesteps is None: if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps: elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end] timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]} intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) time_range = (
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] reversed(range(0, timesteps))
print(f"Running DDIM Sampling with {total_steps} timesteps") if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True) iterator = tqdm(
time_range,
desc='DDIM Sampler',
total=total_steps,
dynamic_ncols=True,
)
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
@ -144,18 +238,30 @@ class DDIMSampler(object):
if mask is not None: if mask is not None:
assert x0 is not None assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img_orig = self.model.q_sample(
img = img_orig * mask + (1. - mask) * img x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, outs = self.p_sample_ddim(
quantize_denoised=quantize_denoised, temperature=temperature, img,
noise_dropout=noise_dropout, score_corrector=score_corrector, cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning) unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs img, pred_x0 = outs
if callback: callback(i) if callback:
if img_callback: img_callback(pred_x0, i) callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates['x_inter'].append(img)
@ -164,42 +270,82 @@ class DDIMSampler(object):
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_ddim(
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, self,
unconditional_guidance_scale=1., unconditional_conditioning=None): x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c]) c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == "eps" assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas = (
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev self.model.alphas_cumprod
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas if use_original_steps
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0 # current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = (
if noise_dropout > 0.: sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0 return x_prev, pred_x0
@ -217,26 +363,51 @@ class DDIMSampler(object):
if noise is None: if noise is None:
noise = torch.randn_like(x0) noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + return (
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
@torch.no_grad() @torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, def decode(
use_original_steps=False): self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start] timesteps = timesteps[:t_start]
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps") print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps) iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent x_dec = x_latent
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) ts = torch.full(
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, (x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning) unconditional_conditioning=unconditional_conditioning,
)
return x_dec return x_dec

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,9 @@
'''wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers''' """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K import k_diffusion as K
import torch import torch
import torch.nn as nn import torch.nn as nn
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
@ -15,8 +16,9 @@ class CFGDenoiser(nn.Module):
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
class KSampler(object): class KSampler(object):
def __init__(self, model, schedule="lms", device="cuda", **kwargs): def __init__(self, model, schedule='lms', device='cuda', **kwargs):
super().__init__() super().__init__()
self.model = K.external.CompVisDenoiser(model) self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule self.schedule = schedule
@ -26,14 +28,16 @@ class KSampler(object):
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2) sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond]) cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
# most of these arguments are ignored and are only present for compatibility with # most of these arguments are ignored and are only present for compatibility with
# other samplers # other samplers
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(
self,
S, S,
batch_size, batch_size,
shape, shape,
@ -42,28 +46,39 @@ class KSampler(object):
normals_sequence=None, normals_sequence=None,
img_callback=None, img_callback=None,
quantize_x0=False, quantize_x0=False,
eta=0., eta=0.0,
mask=None, mask=None,
x0=None, x0=None,
temperature=1., temperature=1.0,
noise_dropout=0., noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
verbose=True, verbose=True,
x_T=None, x_T=None,
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1., unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs **kwargs,
): ):
sigmas = self.model.get_sigmas(S) sigmas = self.model.get_sigmas(S)
if x_T: if x_T:
x = x_T x = x_T
else: else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw x = (
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model) model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale} extra_args = {
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args), 'cond': conditioning,
None) 'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args
),
None,
)

View File

@ -5,11 +5,15 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class PLMSSampler(object): class PLMSSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs): def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
@ -23,41 +27,89 @@ class PLMSSampler(object):
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
if ddim_eta != 0: if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS') raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, self.ddim_timesteps = make_ddim_timesteps(
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' assert (
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) )
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters # ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), (
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose) eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( (1 - self.alphas_cumprod_prev)
1 - self.alphas_cumprod / self.alphas_cumprod_prev)) / (1 - self.alphas_cumprod)
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(
self,
S, S,
batch_size, batch_size,
shape, shape,
@ -66,29 +118,33 @@ class PLMSSampler(object):
normals_sequence=None, normals_sequence=None,
img_callback=None, img_callback=None,
quantize_x0=False, quantize_x0=False,
eta=0., eta=0.0,
mask=None, mask=None,
x0=None, x0=None,
temperature=1., temperature=1.0,
noise_dropout=0., noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
verbose=True, verbose=True,
x_T=None, x_T=None,
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1., unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs **kwargs,
): ):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0] cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else: else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling # sampling
@ -96,11 +152,14 @@ class PLMSSampler(object):
size = (batch_size, C, H, W) size = (batch_size, C, H, W)
# print(f'Data shape for PLMS sampling is {size}') # print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size, samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback, callback=callback,
img_callback=img_callback, img_callback=img_callback,
quantize_denoised=quantize_x0, quantize_denoised=quantize_x0,
mask=mask, x0=x0, mask=mask,
x0=x0,
ddim_use_original_steps=False, ddim_use_original_steps=False,
noise_dropout=noise_dropout, noise_dropout=noise_dropout,
temperature=temperature, temperature=temperature,
@ -114,12 +173,26 @@ class PLMSSampler(object):
return samples, intermediates return samples, intermediates
@torch.no_grad() @torch.no_grad()
def plms_sampling(self, cond, shape, def plms_sampling(
x_T=None, ddim_use_original_steps=False, self,
callback=None, timesteps=None, quantize_denoised=False, cond,
mask=None, x0=None, img_callback=None, log_every_t=100, shape,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,): ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
@ -128,42 +201,81 @@ class PLMSSampler(object):
img = x_T img = x_T
if timesteps is None: if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps: elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end] timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]} intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) time_range = (
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
# print(f"Running PLMS Sampling with {total_steps} timesteps") # print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True) iterator = tqdm(
time_range,
desc='PLMS Sampler',
total=total_steps,
dynamic_ncols=True,
)
old_eps = [] old_eps = []
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long) ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None: if mask is not None:
assert x0 is not None assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img_orig = self.model.q_sample(
img = img_orig * mask + (1. - mask) * img x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, outs = self.p_sample_plms(
quantize_denoised=quantize_denoised, temperature=temperature, img,
noise_dropout=noise_dropout, score_corrector=score_corrector, cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next) old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs img, pred_x0, e_t = outs
old_eps.append(e_t) old_eps.append(e_t)
if len(old_eps) >= 4: if len(old_eps) >= 4:
old_eps.pop(0) old_eps.pop(0)
if callback: callback(i) if callback:
if img_callback: img_callback(pred_x0, i) callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates['x_inter'].append(img)
@ -172,47 +284,95 @@ class PLMSSampler(object):
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_plms(
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, self,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
def get_model_output(x, t): def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c]) c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t_uncond, e_t = self.model.apply_model(
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == "eps" assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas = (
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev self.model.alphas_cumprod
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas if use_original_steps
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index): def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) a_prev = torch.full(
(b, 1, 1, 1), alphas_prev[index], device=device
)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0 # current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = (
if noise_dropout > 0.: sigma_t
* noise_like(x.shape, device, repeat_noise)
* temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0 return x_prev, pred_x0
@ -231,7 +391,12 @@ class PLMSSampler(object):
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3: elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth) # 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 e_t_prime = (
55 * e_t
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

View File

@ -45,19 +45,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = (
nn.Linear(dim, inner_dim), nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
nn.GELU() if not glu
) if not glu else GEGLU(dim, inner_dim) else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
) )
def forward(self, x): def forward(self, x):
@ -74,7 +73,9 @@ def zero_module(module):
def Normalize(in_channels): def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module): class LinearAttention(nn.Module):
@ -88,11 +89,22 @@ class LinearAttention(nn.Module):
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) q, k, v = rearrange(
qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3,
)
k = k.softmax(dim=-1) k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v) context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q) out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w,
)
return self.to_out(out) return self.to_out(out)
@ -102,26 +114,18 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
kernel_size=1, )
stride=1, self.k = torch.nn.Conv2d(
padding=0) in_channels, in_channels, kernel_size=1, stride=1, padding=0
self.k = torch.nn.Conv2d(in_channels, )
in_channels, self.v = torch.nn.Conv2d(
kernel_size=1, in_channels, in_channels, kernel_size=1, stride=1, padding=0
stride=1, )
padding=0) self.proj_out = torch.nn.Conv2d(
self.v = torch.nn.Conv2d(in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels, )
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@ -150,7 +154,9 @@ class SpatialSelfAttention(nn.Module):
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -163,8 +169,7 @@ class CrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
nn.Dropout(dropout)
) )
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
@ -175,7 +180,9 @@ class CrossAttention(nn.Module):
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)
)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@ -194,19 +201,37 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
):
super().__init__() super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, self.attn2 = CrossAttention(
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim) self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint self.checkpoint = checkpoint
def forward(self, x, context=None): def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None): def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x x = self.attn1(self.norm1(x)) + x
@ -223,29 +248,43 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action. Then apply standard transformer action.
Finally, reshape to image Finally, reshape to image
""" """
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None): def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, self.proj_in = nn.Conv2d(
inner_dim, in_channels, inner_dim, kernel_size=1, stride=1, padding=0
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
) )
self.proj_out = zero_module(nn.Conv2d(inner_dim, self.transformer_blocks = nn.ModuleList(
in_channels, [
kernel_size=1, BasicTransformerBlock(
stride=1, inner_dim,
padding=0)) n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
)
)
def forward(self, x, context=None): def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention

View File

@ -36,7 +36,9 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class Upsample(nn.Module): class Upsample(nn.Module):
@ -44,14 +46,14 @@ class Upsample(nn.Module):
super().__init__() super().__init__()
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, self.conv = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
def forward(self, x): def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") x = torch.nn.functional.interpolate(
x, scale_factor=2.0, mode='nearest'
)
if self.with_conv: if self.with_conv:
x = self.conv(x) x = self.conv(x)
return x return x
@ -63,16 +65,14 @@ class Downsample(nn.Module):
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves # no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, self.conv = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=3, stride=2, padding=0
kernel_size=3, )
stride=2,
padding=0)
def forward(self, x): def forward(self, x):
if self.with_conv: if self.with_conv:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x) x = self.conv(x)
else: else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@ -80,8 +80,15 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, def __init__(
dropout, temb_channels=512): self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels out_channels = in_channels if out_channels is None else out_channels
@ -89,34 +96,33 @@ class ResnetBlock(nn.Module):
self.use_conv_shortcut = conv_shortcut self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels) self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, self.conv1 = torch.nn.Conv2d(
out_channels, in_channels, out_channels, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
if temb_channels > 0: if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
out_channels)
self.norm2 = Normalize(out_channels) self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout) self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, out_channels, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, self.conv_shortcut = torch.nn.Conv2d(
in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1,
)
else: else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, self.nin_shortcut = torch.nn.Conv2d(
in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0,
)
def forward(self, x, temb): def forward(self, x, temb):
h = x h = x
@ -143,6 +149,7 @@ class ResnetBlock(nn.Module):
class LinAttnBlock(LinearAttention): class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage""" """to match AttnBlock usage"""
def __init__(self, in_channels): def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels) super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
@ -153,27 +160,18 @@ class AttnBlock(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
kernel_size=1, )
stride=1, self.k = torch.nn.Conv2d(
padding=0) in_channels, in_channels, kernel_size=1, stride=1, padding=0
self.k = torch.nn.Conv2d(in_channels, )
in_channels, self.v = torch.nn.Conv2d(
kernel_size=1, in_channels, in_channels, kernel_size=1, stride=1, padding=0
stride=1, )
padding=0) self.proj_out = torch.nn.Conv2d(
self.v = torch.nn.Conv2d(in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels, )
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@ -194,7 +192,9 @@ class AttnBlock(nn.Module):
# attend to values # attend to values
v = v.reshape(b, c, h * w) v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = torch.bmm(
v, w_
) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w) h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
@ -202,23 +202,43 @@ class AttnBlock(nn.Module):
return x + h_ return x + h_
def make_attn(in_channels, attn_type="vanilla"): def make_attn(in_channels, attn_type='vanilla'):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' assert attn_type in [
print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 'vanilla',
if attn_type == "vanilla": 'linear',
'none',
], f'attn_type {attn_type} unknown'
print(
f"making attention of type '{attn_type}' with {in_channels} in_channels"
)
if attn_type == 'vanilla':
return AttnBlock(in_channels) return AttnBlock(in_channels)
elif attn_type == "none": elif attn_type == 'none':
return nn.Identity(in_channels) return nn.Identity(in_channels)
else: else:
return LinAttnBlock(in_channels) return LinAttnBlock(in_channels)
class Model(nn.Module): class Model(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, self,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): *,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type='vanilla',
):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn:
attn_type = 'linear'
self.ch = ch self.ch = ch
self.temb_ch = self.ch * 4 self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult) self.num_resolutions = len(ch_mult)
@ -230,19 +250,17 @@ class Model(nn.Module):
if self.use_timestep: if self.use_timestep:
# timestep embedding # timestep embedding
self.temb = nn.Module() self.temb = nn.Module()
self.temb.dense = nn.ModuleList([ self.temb.dense = nn.ModuleList(
torch.nn.Linear(self.ch, [
self.temb_ch), torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, torch.nn.Linear(self.temb_ch, self.temb_ch),
self.temb_ch), ]
]) )
# downsampling # downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.conv_in = torch.nn.Conv2d(
self.ch, in_channels, self.ch, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
curr_res = resolution curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult) in_ch_mult = (1,) + tuple(ch_mult)
@ -253,10 +271,14 @@ class Model(nn.Module):
block_in = ch * in_ch_mult[i_level] block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout,
)
)
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(make_attn(block_in, attn_type=attn_type))
@ -270,15 +292,19 @@ class Model(nn.Module):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in, self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
)
# upsampling # upsampling
self.up = nn.ModuleList() self.up = nn.ModuleList()
@ -290,10 +316,14 @@ class Model(nn.Module):
for i_block in range(self.num_res_blocks + 1): for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks: if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level] skip_in = ch * in_ch_mult[i_level]
block.append(ResnetBlock(in_channels=block_in+skip_in, block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout,
)
)
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(make_attn(block_in, attn_type=attn_type))
@ -307,11 +337,9 @@ class Model(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.conv_out = torch.nn.Conv2d(
out_ch, block_in, out_ch, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
def forward(self, x, t=None, context=None): def forward(self, x, t=None, context=None):
# assert x.shape[2] == x.shape[3] == self.resolution # assert x.shape[2] == x.shape[3] == self.resolution
@ -349,7 +377,8 @@ class Model(nn.Module):
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1): for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block]( h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb) torch.cat([h, hs.pop()], dim=1), temb
)
if len(self.up[i_level].attn) > 0: if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h) h = self.up[i_level].attn[i_block](h)
if i_level != 0: if i_level != 0:
@ -366,12 +395,27 @@ class Model(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, self,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", *,
**ignore_kwargs): ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type='vanilla',
**ignore_kwargs,
):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn:
attn_type = 'linear'
self.ch = ch self.ch = ch
self.temb_ch = 0 self.temb_ch = 0
self.num_resolutions = len(ch_mult) self.num_resolutions = len(ch_mult)
@ -380,11 +424,9 @@ class Encoder(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
# downsampling # downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.conv_in = torch.nn.Conv2d(
self.ch, in_channels, self.ch, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
curr_res = resolution curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult) in_ch_mult = (1,) + tuple(ch_mult)
@ -396,10 +438,14 @@ class Encoder(nn.Module):
block_in = ch * in_ch_mult[i_level] block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout,
)
)
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(make_attn(block_in, attn_type=attn_type))
@ -413,23 +459,29 @@ class Encoder(nn.Module):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in, self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
)
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels, 2 * z_channels if double_z else z_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1,
)
def forward(self, x): def forward(self, x):
# timestep embedding # timestep embedding
@ -460,12 +512,28 @@ class Encoder(nn.Module):
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, self,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, *,
attn_type="vanilla", **ignorekwargs): ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type='vanilla',
**ignorekwargs,
):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn:
attn_type = 'linear'
self.ch = ch self.ch = ch
self.temb_ch = 0 self.temb_ch = 0
self.num_resolutions = len(ch_mult) self.num_resolutions = len(ch_mult)
@ -480,27 +548,32 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1] block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1) curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res) self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format( print(
self.z_shape, np.prod(self.z_shape))) 'Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in # z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, self.conv_in = torch.nn.Conv2d(
block_in, z_channels, block_in, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in, self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
)
# upsampling # upsampling
self.up = nn.ModuleList() self.up = nn.ModuleList()
@ -509,10 +582,14 @@ class Decoder(nn.Module):
attn = nn.ModuleList() attn = nn.ModuleList()
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1): for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout,
)
)
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(make_attn(block_in, attn_type=attn_type))
@ -526,11 +603,9 @@ class Decoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.conv_out = torch.nn.Conv2d(
out_ch, block_in, out_ch, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
def forward(self, z): def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:] # assert z.shape[1:] == self.z_shape[1:]
@ -571,25 +646,36 @@ class Decoder(nn.Module):
class SimpleDecoder(nn.Module): class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs): def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__() super().__init__()
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), self.model = nn.ModuleList(
ResnetBlock(in_channels=in_channels, [
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(
in_channels=in_channels,
out_channels=2 * in_channels, out_channels=2 * in_channels,
temb_channels=0, dropout=0.0), temb_channels=0,
ResnetBlock(in_channels=2 * in_channels, dropout=0.0,
),
ResnetBlock(
in_channels=2 * in_channels,
out_channels=4 * in_channels, out_channels=4 * in_channels,
temb_channels=0, dropout=0.0), temb_channels=0,
ResnetBlock(in_channels=4 * in_channels, dropout=0.0,
),
ResnetBlock(
in_channels=4 * in_channels,
out_channels=2 * in_channels, out_channels=2 * in_channels,
temb_channels=0, dropout=0.0), temb_channels=0,
dropout=0.0,
),
nn.Conv2d(2 * in_channels, in_channels, 1), nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)]) Upsample(in_channels, with_conv=True),
]
)
# end # end
self.norm_out = Normalize(in_channels) self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels, self.conv_out = torch.nn.Conv2d(
out_channels, in_channels, out_channels, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
def forward(self, x): def forward(self, x):
for i, layer in enumerate(self.model): for i, layer in enumerate(self.model):
@ -605,8 +691,16 @@ class SimpleDecoder(nn.Module):
class UpsampleDecoder(nn.Module): class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, def __init__(
ch_mult=(2,2), dropout=0.0): self,
in_channels,
out_channels,
ch,
num_res_blocks,
resolution,
ch_mult=(2, 2),
dropout=0.0,
):
super().__init__() super().__init__()
# upsampling # upsampling
self.temb_ch = 0 self.temb_ch = 0
@ -620,10 +714,14 @@ class UpsampleDecoder(nn.Module):
res_block = [] res_block = []
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1): for i_block in range(self.num_res_blocks + 1):
res_block.append(ResnetBlock(in_channels=block_in, res_block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout,
)
)
block_in = block_out block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block)) self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
@ -632,11 +730,9 @@ class UpsampleDecoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.conv_out = torch.nn.Conv2d(
out_channels, block_in, out_channels, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1,
padding=1)
def forward(self, x): def forward(self, x):
# upsampling # upsampling
@ -653,26 +749,41 @@ class UpsampleDecoder(nn.Module):
class LatentRescaler(nn.Module): class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): def __init__(
self, factor, in_channels, mid_channels, out_channels, depth=2
):
super().__init__() super().__init__()
# residual block, interpolate, residual block # residual block, interpolate, residual block
self.factor = factor self.factor = factor
self.conv_in = nn.Conv2d(in_channels, self.conv_in = nn.Conv2d(
mid_channels, in_channels, mid_channels, kernel_size=3, stride=1, padding=1
kernel_size=3, )
stride=1, self.res_block1 = nn.ModuleList(
padding=1) [
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels, out_channels=mid_channels,
temb_channels=0, temb_channels=0,
dropout=0.0) for _ in range(depth)]) dropout=0.0,
)
for _ in range(depth)
]
)
self.attn = AttnBlock(mid_channels) self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, self.res_block2 = nn.ModuleList(
[
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels, out_channels=mid_channels,
temb_channels=0, temb_channels=0,
dropout=0.0) for _ in range(depth)]) dropout=0.0,
)
for _ in range(depth)
]
)
self.conv_out = nn.Conv2d(mid_channels, self.conv_out = nn.Conv2d(
mid_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
) )
@ -681,7 +792,13 @@ class LatentRescaler(nn.Module):
x = self.conv_in(x) x = self.conv_in(x)
for block in self.res_block1: for block in self.res_block1:
x = block(x, None) x = block(x, None)
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) x = torch.nn.functional.interpolate(
x,
size=(
int(round(x.shape[2] * self.factor)),
int(round(x.shape[3] * self.factor)),
),
)
x = self.attn(x) x = self.attn(x)
for block in self.res_block2: for block in self.res_block2:
x = block(x, None) x = block(x, None)
@ -690,17 +807,42 @@ class LatentRescaler(nn.Module):
class MergedRescaleEncoder(nn.Module): class MergedRescaleEncoder(nn.Module):
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, def __init__(
attn_resolutions, dropout=0.0, resamp_with_conv=True, self,
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__() super().__init__()
intermediate_chn = ch * ch_mult[-1] intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, self.encoder = Encoder(
z_channels=intermediate_chn, double_z=False, resolution=resolution, in_channels=in_channels,
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, num_res_blocks=num_res_blocks,
out_ch=None) ch=ch,
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, ch_mult=ch_mult,
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth,
)
def forward(self, x): def forward(self, x):
x = self.encoder(x) x = self.encoder(x)
@ -709,15 +851,41 @@ class MergedRescaleEncoder(nn.Module):
class MergedRescaleDecoder(nn.Module): class MergedRescaleDecoder(nn.Module):
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), def __init__(
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__() super().__init__()
tmp_chn = z_channels * ch_mult[-1] tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, self.decoder = Decoder(
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, out_ch=out_ch,
ch_mult=ch_mult, resolution=resolution, ch=ch) z_channels=tmp_chn,
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, attn_resolutions=attn_resolutions,
out_channels=tmp_chn, depth=rescale_module_depth) dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth,
)
def forward(self, x): def forward(self, x):
x = self.rescaler(x) x = self.rescaler(x)
@ -726,17 +894,32 @@ class MergedRescaleDecoder(nn.Module):
class Upsampler(nn.Module): class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): def __init__(
self, in_size, out_size, in_channels, out_channels, ch_mult=2
):
super().__init__() super().__init__()
assert out_size >= in_size assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1 num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1.+ (out_size % in_size) factor_up = 1.0 + (out_size % in_size)
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") print(
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}'
out_channels=in_channels) )
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, self.rescaler = LatentRescaler(
attn_resolutions=[], in_channels=None, ch=in_channels, factor=factor_up,
ch_mult=[ch_mult for _ in range(num_blocks)]) in_channels=in_channels,
mid_channels=2 * in_channels,
out_channels=in_channels,
)
self.decoder = Decoder(
out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)],
)
def forward(self, x): def forward(self, x):
x = self.rescaler(x) x = self.rescaler(x)
@ -745,42 +928,55 @@ class Upsampler(nn.Module):
class Resize(nn.Module): class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinear"): def __init__(self, in_channels=None, learned=False, mode='bilinear'):
super().__init__() super().__init__()
self.with_conv = learned self.with_conv = learned
self.mode = mode self.mode = mode
if self.with_conv: if self.with_conv:
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") print(
f'Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode'
)
raise NotImplementedError() raise NotImplementedError()
assert in_channels is not None assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves # no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, self.conv = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=4, stride=2, padding=1
kernel_size=4, )
stride=2,
padding=1)
def forward(self, x, scale_factor=1.0): def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0: if scale_factor == 1.0:
return x return x
else: else:
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) x = torch.nn.functional.interpolate(
x,
mode=self.mode,
align_corners=False,
scale_factor=scale_factor,
)
return x return x
class FirstStagePostProcessor(nn.Module):
def __init__(self, ch_mult:list, in_channels, class FirstStagePostProcessor(nn.Module):
def __init__(
self,
ch_mult: list,
in_channels,
pretrained_model: nn.Module = None, pretrained_model: nn.Module = None,
reshape=False, reshape=False,
n_channels=None, n_channels=None,
dropout=0., dropout=0.0,
pretrained_config=None): pretrained_config=None,
):
super().__init__() super().__init__()
if pretrained_config is None: if pretrained_config is None:
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' assert (
pretrained_model is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model self.pretrained_model = pretrained_model
else: else:
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' assert (
pretrained_config is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config) self.instantiate_pretrained(pretrained_config)
self.do_reshape = reshape self.do_reshape = reshape
@ -789,21 +985,27 @@ class FirstStagePostProcessor(nn.Module):
n_channels = self.pretrained_model.encoder.ch n_channels = self.pretrained_model.encoder.ch
self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, self.proj = nn.Conv2d(
stride=1,padding=1) in_channels, n_channels, kernel_size=3, stride=1, padding=1
)
blocks = [] blocks = []
downs = [] downs = []
ch_in = n_channels ch_in = n_channels
for m in ch_mult: for m in ch_mult:
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) blocks.append(
ResnetBlock(
in_channels=ch_in,
out_channels=m * n_channels,
dropout=dropout,
)
)
ch_in = m * n_channels ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False)) downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks) self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs) self.downsampler = nn.ModuleList(downs)
def instantiate_pretrained(self, config): def instantiate_pretrained(self, config):
model = instantiate_from_config(config) model = instantiate_from_config(config)
self.pretrained_model = model.eval() self.pretrained_model = model.eval()
@ -811,7 +1013,6 @@ class FirstStagePostProcessor(nn.Module):
for param in self.pretrained_model.parameters(): for param in self.pretrained_model.parameters():
param.requires_grad = False param.requires_grad = False
@torch.no_grad() @torch.no_grad()
def encode_with_pretrained(self, x): def encode_with_pretrained(self, x):
c = self.pretrained_model.encode(x) c = self.pretrained_model.encode(x)
@ -832,4 +1033,3 @@ class FirstStagePostProcessor(nn.Module):
if self.do_reshape: if self.do_reshape:
z = rearrange(z, 'b c h w -> b (h w) c') z = rearrange(z, 'b c h w -> b (h w) c')
return z return z

View File

@ -24,6 +24,7 @@ from ldm.modules.attention import SpatialTransformer
def convert_module_to_f16(x):
    """Placeholder that accepts ``x`` and does nothing with it."""
    return None
def convert_module_to_f32(x):
    """Placeholder that accepts ``x`` and does nothing with it."""
    return None
@ -42,7 +43,9 @@ class AttentionPool2d(nn.Module):
output_dim: int = None, output_dim: int = None,
): ):
super().__init__() super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels self.num_heads = embed_dim // num_heads_channels
@ -97,35 +100,43 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(
    self, channels, use_conv, dims=2, out_channels=None, padding=1
):
    """Store the layer configuration and optionally build the convolution.

    :param channels: channel count stored as ``self.channels``.
    :param use_conv: when truthy, a kernel-size-3 convolution is created
                     through ``conv_nd`` and stored as ``self.conv``.
    :param dims: dimensionality forwarded to ``conv_nd``.
    :param out_channels: output channel count; ``channels`` is used when
                         this is falsy.
    :param padding: padding forwarded to the convolution.
    """
    super().__init__()
    self.dims = dims
    self.use_conv = use_conv
    self.channels = channels
    self.out_channels = out_channels or channels
    if not use_conv:
        return
    self.conv = conv_nd(
        dims,
        self.channels,
        self.out_channels,
        3,
        padding=padding,
    )
def forward(self, x):
    """Upsample ``x`` with nearest-neighbour interpolation.

    When ``self.dims`` is 3, an explicit target size is used that keeps
    axis 2 unchanged and doubles axes 3 and 4. Otherwise ``x`` is
    interpolated with ``scale_factor=2``. If ``self.use_conv`` is set,
    ``self.conv`` is applied to the result.

    :param x: tensor whose size along axis 1 must equal ``self.channels``.
    :return: the upsampled (and optionally convolved) tensor.
    """
    assert x.shape[1] == self.channels
    if self.dims != 3:
        x = F.interpolate(x, scale_factor=2, mode='nearest')
    else:
        target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
        x = F.interpolate(x, target_size, mode='nearest')
    return self.conv(x) if self.use_conv else x
class TransposedUpsample(nn.Module):
    """Learned upsampling through a stride-2 transposed convolution, no padding.

    :param channels: channels in the input.
    :param out_channels: channels in the output; ``channels`` is used when
                         this is falsy.
    :param ks: kernel size of the transposed convolution.
    """

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.up = nn.ConvTranspose2d(
            in_channels=self.channels,
            out_channels=self.out_channels,
            kernel_size=ks,
            stride=2,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x`` and return the result."""
        upsampled = self.up(x)
        return upsampled
@ -140,7 +151,9 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -149,7 +162,12 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
if use_conv: if use_conv:
self.op = conv_nd( self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
) )
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
@ -219,7 +237,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
linear( linear(
emb_channels, emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, 2 * self.out_channels
if use_scale_shift_norm
else self.out_channels,
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
@ -227,7 +247,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1
)
), ),
) )
@ -238,7 +260,9 @@ class ResBlock(TimestepBlock):
dims, channels, self.out_channels, 3, padding=1 dims, channels, self.out_channels, 3, padding=1
) )
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1
)
def forward(self, x, emb): def forward(self, x, emb):
""" """
@ -251,7 +275,6 @@ class ResBlock(TimestepBlock):
self._forward, (x, emb), self.parameters(), self.use_checkpoint self._forward, (x, emb), self.parameters(), self.use_checkpoint
) )
def _forward(self, x, emb): def _forward(self, x, emb):
if self.updown: if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@ -297,7 +320,7 @@ class AttentionBlock(nn.Module):
else: else:
assert ( assert (
channels % num_head_channels == 0 channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.norm = normalization(channels) self.norm = normalization(channels)
@ -312,7 +335,9 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
    """Run ``self._forward`` on ``x`` through the ``checkpoint`` helper.

    The module's parameters are handed to ``checkpoint`` alongside the
    input, and the checkpoint flag is passed as a literal ``True``.
    """
    # TODO: check checkpoint usage, is True
    # TODO: fix the .half call!!!
    inputs = (x,)
    params = self.parameters()
    return checkpoint(self._forward, inputs, params, True)
def _forward(self, x): def _forward(self, x):
@ -362,13 +387,15 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0 assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads) ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1
)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum( weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale 'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards ) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v) a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@staticmethod @staticmethod
@ -397,12 +424,14 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1) q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum( weight = th.einsum(
"bct,bcs->bts", 'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length), (q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards ) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) a = th.einsum(
'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@staticmethod @staticmethod
@ -469,11 +498,16 @@ class UNetModel(nn.Module):
): ):
super().__init__() super().__init__()
if use_spatial_transformer: if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' assert (
context_dim is not None
), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None: if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' assert (
use_spatial_transformer
), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig: if type(context_dim) == ListConfig:
context_dim = list(context_dim) context_dim = list(context_dim)
@ -481,10 +515,14 @@ class UNetModel(nn.Module):
num_heads_upsample = num_heads num_heads_upsample = num_heads
if num_heads == -1: if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' assert (
num_head_channels != -1
), 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1: if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' assert (
num_heads != -1
), 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size self.image_size = image_size
self.in_channels = in_channels self.in_channels = in_channels
@ -546,7 +584,11 @@ class UNetModel(nn.Module):
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
# num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append( layers.append(
AttentionBlock( AttentionBlock(
ch, ch,
@ -554,8 +596,14 @@ class UNetModel(nn.Module):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
) )
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -593,7 +641,11 @@ class UNetModel(nn.Module):
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
# num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
self.middle_block = TimestepEmbedSequential( self.middle_block = TimestepEmbedSequential(
ResBlock( ResBlock(
ch, ch,
@ -609,8 +661,14 @@ class UNetModel(nn.Module):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
), ),
ResBlock( ResBlock(
ch, ch,
@ -647,7 +705,11 @@ class UNetModel(nn.Module):
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
# num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append( layers.append(
AttentionBlock( AttentionBlock(
ch, ch,
@ -655,8 +717,14 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample, num_heads=num_heads_upsample,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
) )
) )
if level and i == num_res_blocks: if level and i == num_res_blocks:
@ -673,7 +741,9 @@ class UNetModel(nn.Module):
up=True, up=True,
) )
if resblock_updown if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -682,7 +752,9 @@ class UNetModel(nn.Module):
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)
),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
@ -718,9 +790,11 @@ class UNetModel(nn.Module):
""" """
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), 'must specify y if and only if the model is class-conditional'
hs = [] hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:
@ -768,9 +842,9 @@ class EncoderUNetModel(nn.Module):
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
use_new_attention_order=False, use_new_attention_order=False,
pool="adaptive", pool='adaptive',
*args, *args,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
@ -888,7 +962,7 @@ class EncoderUNetModel(nn.Module):
) )
self._feature_size += ch self._feature_size += ch
self.pool = pool self.pool = pool
if pool == "adaptive": if pool == 'adaptive':
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
@ -896,7 +970,7 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)), zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(), nn.Flatten(),
) )
elif pool == "attention": elif pool == 'attention':
assert num_head_channels != -1 assert num_head_channels != -1
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
@ -905,13 +979,13 @@ class EncoderUNetModel(nn.Module):
(image_size // ds), ch, num_head_channels, out_channels (image_size // ds), ch, num_head_channels, out_channels
), ),
) )
elif pool == "spatial": elif pool == 'spatial':
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048), nn.Linear(self._feature_size, 2048),
nn.ReLU(), nn.ReLU(),
nn.Linear(2048, self.out_channels), nn.Linear(2048, self.out_channels),
) )
elif pool == "spatial_v2": elif pool == 'spatial_v2':
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048), nn.Linear(self._feature_size, 2048),
normalization(2048), normalization(2048),
@ -919,7 +993,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels), nn.Linear(2048, self.out_channels),
) )
else: else:
raise NotImplementedError(f"Unexpected {pool} pooling") raise NotImplementedError(f'Unexpected {pool} pooling')
def convert_to_fp16(self): def convert_to_fp16(self):
""" """
@ -942,20 +1016,21 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps. :param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs. :return: an [N x K] Tensor of outputs.
""" """
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
results = [] results = []
h = x.type(self.dtype) h = x.type(self.dtype)
for module in self.input_blocks: for module in self.input_blocks:
h = module(h, emb) h = module(h, emb)
if self.pool.startswith("spatial"): if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3))) results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb) h = self.middle_block(h, emb)
if self.pool.startswith("spatial"): if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3))) results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1) h = th.cat(results, axis=-1)
return self.out(h) return self.out(h)
else: else:
h = h.type(x.dtype) h = h.type(x.dtype)
return self.out(h) return self.out(h)

View File

@ -18,15 +18,24 @@ from einops import repeat
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): def make_beta_schedule(
if schedule == "linear": schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == 'linear':
betas = ( betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64,
)
** 2
) )
elif schedule == "cosine": elif schedule == 'cosine':
timesteps = ( timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s
) )
alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2) alphas = torch.cos(alphas).pow(2)
@ -34,23 +43,41 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = 1 - alphas[1:] / alphas[:-1] betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999) betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear": elif schedule == 'sqrt_linear':
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) betas = torch.linspace(
elif schedule == "sqrt": linear_start, linear_end, n_timestep, dtype=torch.float64
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 )
elif schedule == 'sqrt':
betas = (
torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
** 0.5
)
else: else:
raise ValueError(f"schedule '{schedule}' unknown.") raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy() return betas.numpy()
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == 'uniform': if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad': elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) ddim_timesteps = (
(
np.linspace(
0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps
)
)
** 2
).astype(int)
else: else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps # assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling) # add one to get the final alpha values right (the ones from first scale to data during sampling)
@ -60,17 +87,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
return steps_out return steps_out
def make_ddim_sampling_parameters(
    alphacums, ddim_timesteps, eta, verbose=True
):
    """Pick the alphas used by the DDIM sampler and derive its sigmas.

    :param alphacums: sequence of alpha values, indexed by
                      ``ddim_timesteps`` (presumably the cumulative
                      product of alphas -- confirm against the caller).
    :param ddim_timesteps: integer index array selecting entries of
                           ``alphacums``.
    :param eta: scalar multiplier applied to the sigma schedule.
    :param verbose: when truthy, print the selected alphas and sigmas.
    :return: the tuple ``(sigmas, alphas, alphas_prev)``.
    """
    # Alphas at the selected steps, and the same series shifted by one step
    # with ``alphacums[0]`` as the first entry.
    alphas = alphacums[ddim_timesteps]
    shifted = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0]] + shifted)

    # According to the formula provided in https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alphas_prev) / (1 - alphas)
    sigmas = eta * np.sqrt(variance_ratio * (1 - alphas / alphas_prev))

    if not verbose:
        return sigmas, alphas, alphas_prev

    print(
        f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
    )
    print(
        f'For the chosen value of eta, which is {eta}, '
        f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
    )
    return sigmas, alphas, alphas_prev
@ -109,7 +146,9 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments. explicitly take as arguments.
:param flag: if False, disable gradient checkpointing. :param flag: if False, disable gradient checkpointing.
""" """
if False: # disabled checkpointing to allow requires_grad = False for main model if (
False
): # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params) args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args) return CheckpointFunction.apply(func, len(inputs), *args)
else: else:
@ -129,7 +168,9 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *output_grads): def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad(): with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the # Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d # Tensor storage in place, which is not allowed for detach()'d
@ -160,12 +201,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only: if not repeat_only:
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half -math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device) ).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else: else:
embedding = repeat(timesteps, 'b -> b d', d=dim) embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding return embedding
@ -215,6 +260,7 @@ class GroupNorm32(nn.GroupNorm):
def forward(self, x):
    """Run the parent ``forward`` on ``x.float()`` and cast back to ``x``'s dtype."""
    source_dtype = x.dtype
    normalized = super().forward(x.float())
    return normalized.type(source_dtype)
def conv_nd(dims, *args, **kwargs): def conv_nd(dims, *args, **kwargs):
""" """
Create a 1D, 2D, or 3D convolution module. Create a 1D, 2D, or 3D convolution module.
@ -225,7 +271,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs) return nn.Conv2d(*args, **kwargs)
elif dims == 3: elif dims == 3:
return nn.Conv3d(*args, **kwargs) return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs): def linear(*args, **kwargs):
@ -245,15 +291,16 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs) return nn.AvgPool2d(*args, **kwargs)
elif dims == 3: elif dims == 3:
return nn.AvgPool3d(*args, **kwargs) return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module): class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
    """Instantiate the two sub-conditioners from their configs.

    :param c_concat_config: config given to ``instantiate_from_config``;
                            the result is stored as ``concat_conditioner``.
    :param c_crossattn_config: config given to ``instantiate_from_config``;
                               the result is stored as
                               ``crossattn_conditioner``.
    """
    super().__init__()
    concat = instantiate_from_config(c_concat_config)
    crossattn = instantiate_from_config(c_crossattn_config)
    self.concat_conditioner = concat
    self.crossattn_conditioner = crossattn
def forward(self, c_concat, c_crossattn): def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat) c_concat = self.concat_conditioner(c_concat)
@ -262,6 +309,8 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
    """Draw ``torch.randn`` noise of the given shape on ``device``.

    :param shape: shape of the returned tensor.
    :param device: device on which the noise is created.
    :param repeat: when truthy, a single sample of shape
                   ``(1, *shape[1:])`` is drawn and repeated ``shape[0]``
                   times along the first axis; otherwise the full shape
                   is sampled directly.
    :return: a tensor of shape ``shape``.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    single = torch.randn((1, *shape[1:]), device=device)
    repeat_counts = (shape[0],) + (1,) * (len(shape) - 1)
    return single.repeat(*repeat_counts)

View File

@ -30,33 +30,45 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar) self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar) self.var = torch.exp(self.logvar)
if self.deterministic: if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
    """Return ``mean + std * eps`` with ``eps`` drawn from ``torch.randn``.

    The noise has the shape of ``self.mean`` and is moved to the device of
    ``self.parameters`` before it is scaled.
    """
    eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
    return self.mean + self.std * eps
def kl(self, other=None):
    """Compute the closed-form Gaussian KL term, summed over axes 1, 2 and 3.

    :param other: optional second distribution exposing ``mean``, ``var``
                  and ``logvar``. When omitted, the expression
                  ``mean**2 + var - 1 - logvar`` is used instead.
    :return: ``torch.Tensor([0.0])`` when ``self.deterministic`` is set,
             otherwise half the summed per-element term.
    """
    if self.deterministic:
        return torch.Tensor([0.0])

    if other is None:
        per_element = torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
    else:
        per_element = (
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar
        )
    return 0.5 * torch.sum(per_element, dim=[1, 2, 3])
def nll(self, sample, dims=(1, 2, 3)):
    """Return the Gaussian negative log-likelihood of ``sample``.

    The per-element term ``log(2*pi) + logvar + (sample - mean)**2 / var``
    is summed over ``dims`` and halved.

    :param sample: tensor compared against ``self.mean``.
    :param dims: axes to sum over. The default is an immutable tuple so
                 the default object cannot be mutated between calls;
                 lists passed by callers are still accepted.
    :return: ``torch.Tensor([0.0])`` when ``self.deterministic`` is set,
             otherwise the reduced negative log-likelihood.
    """
    if self.deterministic:
        return torch.Tensor([0.0])
    logtwopi = np.log(2.0 * np.pi)
    squared_error = torch.pow(sample - self.mean, 2)
    return 0.5 * torch.sum(
        logtwopi + self.logvar + squared_error / self.var,
        dim=dims,
    )
def mode(self):
    """Return ``self.mean`` as the mode of the distribution."""
    most_likely = self.mean
    return most_likely
@ -74,7 +86,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
tensor = obj tensor = obj
break break
assert tensor is not None, "at least one argument must be a Tensor" assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to # Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp(). # Tensors, but it does not work for torch.exp().

View File

@ -10,8 +10,12 @@ class LitEma(nn.Module):
self.m_name2s_name = {} self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates self.register_buffer(
else torch.tensor(-1,dtype=torch.int)) 'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if p.requires_grad: if p.requires_grad:
@ -27,7 +31,9 @@ class LitEma(nn.Module):
if self.num_updates >= 0: if self.num_updates >= 0:
self.num_updates += 1 self.num_updates += 1
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) decay = min(
self.decay, (1 + self.num_updates) / (10 + self.num_updates)
)
one_minus_decay = 1.0 - decay one_minus_decay = 1.0 - decay
@ -38,8 +44,12 @@ class LitEma(nn.Module):
for key in m_param: for key in m_param:
if m_param[key].requires_grad: if m_param[key].requires_grad:
sname = self.m_name2s_name[key] sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname] = shadow_params[sname].type_as(
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) m_param[key]
)
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else: else:
assert not key in self.m_name2s_name assert not key in self.m_name2s_name
@ -48,7 +58,9 @@ class LitEma(nn.Module):
shadow_params = dict(self.named_buffers()) shadow_params = dict(self.named_buffers())
for key in m_param: for key in m_param:
if m_param[key].requires_grad: if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data
)
else: else:
assert not key in self.m_name2s_name assert not key in self.m_name2s_name

View File

@ -8,18 +8,29 @@ from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from functools import partial from functools import partial
DEFAULT_PLACEHOLDER_TOKEN = ["*"] DEFAULT_PLACEHOLDER_TOKEN = ['*']
PROGRESSIVE_SCALE = 2000 PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string):
    """Tokenize ``string`` and return the id stored at position ``[0, 1]``.

    ``tokenizer`` is called with a fixed padded length of 77 and must return
    a mapping whose ``'input_ids'`` entry is a tensor. The assertion requires
    that exactly two ids differ from 49407.
    """
    encoded = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding='max_length',
        return_tensors='pt',
    )
    ids = encoded['input_ids']
    # NOTE(review): 49407 is presumably CLIP's end-of-text/padding id, which
    # would leave the start marker plus one real token -- confirm.
    non_filler_count = torch.count_nonzero(ids - 49407)
    assert (
        non_filler_count == 2
    ), f"String '{string}' maps to more than a single token. Please use another string"
    return ids[0, 1]
def get_bert_token_for_string(tokenizer, string): def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string) token = tokenizer(string)
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
@ -28,6 +39,7 @@ def get_bert_token_for_string(tokenizer, string):
return token return token
def get_embedding_for_clip_token(embedder, token):
    """Embed a single token and return element ``[0, 0]`` of the result.

    A leading dimension is added to ``token`` before it is passed to
    ``embedder``.
    """
    batched_token = token.unsqueeze(0)
    embedded = embedder(batched_token)
    return embedded[0, 0]
@ -41,7 +53,7 @@ class EmbeddingManager(nn.Module):
per_image_tokens=False, per_image_tokens=False,
num_vectors_per_token=1, num_vectors_per_token=1,
progressive_words=False, progressive_words=False,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
@ -49,21 +61,32 @@ class EmbeddingManager(nn.Module):
self.string_to_param_dict = nn.ParameterDict() self.string_to_param_dict = nn.ParameterDict()
self.initial_embeddings = nn.ParameterDict() # These should not be optimized self.initial_embeddings = (
nn.ParameterDict()
) # These should not be optimized
self.progressive_words = progressive_words self.progressive_words = progressive_words
self.progressive_counter = 0 self.progressive_counter = 0
self.max_vectors_per_token = num_vectors_per_token self.max_vectors_per_token = num_vectors_per_token
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder if hasattr(
embedder, 'tokenizer'
): # using Stable Diffusion's CLIP encoder
self.is_clip = True self.is_clip = True
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) get_token_for_string = partial(
get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) get_clip_token_for_string, embedder.tokenizer
)
get_embedding_for_tkn = partial(
get_embedding_for_clip_token,
embedder.transformer.text_model.embeddings,
)
token_dim = 1280 token_dim = 1280
else: # using LDM's BERT encoder else: # using LDM's BERT encoder
self.is_clip = False self.is_clip = False
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) get_token_for_string = partial(
get_bert_token_for_string, embedder.tknz_fn
)
get_embedding_for_tkn = embedder.transformer.token_emb get_embedding_for_tkn = embedder.transformer.token_emb
token_dim = 1280 token_dim = 1280
@ -78,12 +101,31 @@ class EmbeddingManager(nn.Module):
init_word_token = get_token_for_string(initializer_words[idx]) init_word_token = get_token_for_string(initializer_words[idx])
with torch.no_grad(): with torch.no_grad():
init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) init_word_embedding = get_embedding_for_tkn(
init_word_token.cpu()
)
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) token_params = torch.nn.Parameter(
self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=True,
)
self.initial_embeddings[
placeholder_string
] = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=False,
)
else: else:
token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) token_params = torch.nn.Parameter(
torch.rand(
size=(num_vectors_per_token, token_dim),
requires_grad=True,
)
)
self.string_to_token_dict[placeholder_string] = token self.string_to_token_dict[placeholder_string] = token
self.string_to_param_dict[placeholder_string] = token_params self.string_to_param_dict[placeholder_string] = token_params
@ -95,36 +137,69 @@ class EmbeddingManager(nn.Module):
): ):
b, n, device = *tokenized_text.shape, tokenized_text.device b, n, device = *tokenized_text.shape, tokenized_text.device
for placeholder_string, placeholder_token in self.string_to_token_dict.items(): for (
placeholder_string,
placeholder_token,
) in self.string_to_token_dict.items():
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) placeholder_embedding = self.string_to_param_dict[
placeholder_string
].to(device)
if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement if (
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) self.max_vectors_per_token == 1
): # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(
tokenized_text == placeholder_token.to(device)
)
embedded_text[placeholder_idx] = placeholder_embedding embedded_text[placeholder_idx] = placeholder_embedding
else: # otherwise, need to insert and keep track of changing indices else: # otherwise, need to insert and keep track of changing indices
if self.progressive_words: if self.progressive_words:
self.progressive_counter += 1 self.progressive_counter += 1
max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE max_step_tokens = (
1 + self.progressive_counter // PROGRESSIVE_SCALE
)
else: else:
max_step_tokens = self.max_vectors_per_token max_step_tokens = self.max_vectors_per_token
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) num_vectors_for_token = min(
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if placeholder_rows.nelement() == 0: if placeholder_rows.nelement() == 0:
continue continue
sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) sorted_cols, sort_idx = torch.sort(
placeholder_cols, descending=True
)
sorted_rows = placeholder_rows[sort_idx] sorted_rows = placeholder_rows[sort_idx]
for idx in range(len(sorted_rows)): for idx in range(len(sorted_rows)):
row = sorted_rows[idx] row = sorted_rows[idx]
col = sorted_cols[idx] col = sorted_cols[idx]
new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] new_token_row = torch.cat(
new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] [
tokenized_text[row][:col],
placeholder_token.repeat(num_vectors_for_token).to(
device
),
tokenized_text[row][col + 1 :],
],
axis=0,
)[:n]
new_embed_row = torch.cat(
[
embedded_text[row][:col],
placeholder_embedding[:num_vectors_for_token],
embedded_text[row][col + 1 :],
],
axis=0,
)[:n]
embedded_text[row] = new_embed_row embedded_text[row] = new_embed_row
tokenized_text[row] = new_token_row tokenized_text[row] = new_token_row
@ -132,18 +207,27 @@ class EmbeddingManager(nn.Module):
return embedded_text return embedded_text
def save(self, ckpt_path):
    """Write the placeholder token and parameter mappings to ``ckpt_path``.

    The checkpoint is a dict with the keys ``'string_to_token'`` and
    ``'string_to_param'``, serialized with ``torch.save``.
    """
    checkpoint = {}
    checkpoint['string_to_token'] = self.string_to_token_dict
    checkpoint['string_to_param'] = self.string_to_param_dict
    torch.save(checkpoint, ckpt_path)
def load(self, ckpt_path):
    """Replace both placeholder mappings with those stored at ``ckpt_path``.

    The checkpoint is read onto the CPU and must contain the keys
    ``'string_to_token'`` and ``'string_to_param'``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    checkpoint = torch.load(ckpt_path, map_location='cpu')

    self.string_to_token_dict = checkpoint['string_to_token']
    self.string_to_param_dict = checkpoint['string_to_param']
def get_embedding_norms_squared(self):
    """Return the squared L2 norm of every learned placeholder vector.

    All values of ``self.string_to_param_dict`` are concatenated along the
    first axis and each row is reduced to the sum of its squared entries.
    """
    # num_placeholders x embedding_dim
    stacked = torch.cat(list(self.string_to_param_dict.values()), axis=0)
    # num_placeholders
    return (stacked * stacked).sum(axis=-1)
@ -152,13 +236,18 @@ class EmbeddingManager(nn.Module):
def embedding_to_coarse_loss(self):
    """Accumulate how far each optimized embedding moved from its initial value.

    For every key in ``self.initial_embeddings`` the difference between the
    optimized parameter and a copy of the initial one is multiplied by its
    own transpose, divided by the number of initial embeddings and added to
    the running total. Returns ``0.0`` when there are no initial embeddings.
    """
    total = 0.0
    count = len(self.initial_embeddings)

    for name in self.initial_embeddings:
        current = self.string_to_param_dict[name]
        reference = self.initial_embeddings[name].clone().to(current.device)
        delta = current - reference
        total = total + delta @ delta.T / count

    return total

View File

@ -6,7 +6,11 @@ from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
import kornia import kornia
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,
) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
def _expand_mask(mask, dtype, tgt_len=None): def _expand_mask(mask, dtype, tgt_len=None):
""" """
@ -15,11 +19,16 @@ def _expand_mask(mask, dtype, tgt_len = None):
bsz, src_len = mask.size() bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _build_causal_attention_mask(bsz, seq_len, dtype): def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens # lazily create causal attention mask, with full attention between the vision tokens
@ -30,6 +39,7 @@ def _build_causal_attention_mask(bsz, seq_len, dtype):
mask = mask.unsqueeze(1) # expand mask mask = mask.unsqueeze(1) # expand mask
return mask return mask
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -38,7 +48,6 @@ class AbstractEncoder(nn.Module):
raise NotImplementedError raise NotImplementedError
class ClassEmbedder(nn.Module): class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'): def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__() super().__init__()
@ -56,11 +65,17 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder): class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers""" """Some transformer encoder layers"""
def __init__(
    self, n_embed, n_layer, vocab_size, max_seq_len=77, device='cuda'
):
    """Build a ``TransformerWrapper`` around an ``Encoder`` stack.

    ``n_embed`` and ``n_layer`` become the encoder's ``dim`` and ``depth``;
    ``vocab_size`` and ``max_seq_len`` are handed to the wrapper. ``device``
    is only stored on the instance here.
    """
    super().__init__()
    self.device = device
    attention_stack = Encoder(dim=n_embed, depth=n_layer)
    self.transformer = TransformerWrapper(
        num_tokens=vocab_size,
        max_seq_len=max_seq_len,
        attn_layers=attention_stack,
    )
def forward(self, tokens): def forward(self, tokens):
tokens = tokens.to(self.device) # meh tokens = tokens.to(self.device) # meh
@ -73,26 +88,41 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder): class BERTTokenizer(AbstractEncoder):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device='cuda', vq_interface=True, max_length=77):
    """Load the cached ``bert-base-uncased`` tokenizer and store settings.

    Only local files are consulted; if they cannot be read the process
    exits with a hint to run the preload script.
    """
    super().__init__()
    from transformers import (
        BertTokenizerFast,
    )  # TODO: add to requirements

    # Modified to allow to run on non-internet connected compute nodes.
    # Model needs to be loaded into cache from an internet-connected machine
    # by running:
    #   from transformers import BertTokenizerFast
    #   BertTokenizerFast.from_pretrained("bert-base-uncased")
    try:
        tokenizer = BertTokenizerFast.from_pretrained(
            'bert-base-uncased', local_files_only=True
        )
    except OSError:
        raise SystemExit(
            "* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine."
        )
    self.tokenizer = tokenizer
    self.device = device
    self.vq_interface = vq_interface
    self.max_length = max_length
def forward(self, text):
    """Tokenize ``text`` and return the ``'input_ids'`` tensor on ``self.device``.

    The tokenizer is asked to truncate and pad to ``self.max_length`` and to
    return PyTorch tensors.
    """
    encoded = self.tokenizer(
        text,
        truncation=True,
        max_length=self.max_length,
        return_length=True,
        return_overflowing_tokens=False,
        padding='max_length',
        return_tensors='pt',
    )
    return encoded['input_ids'].to(self.device)
@torch.no_grad() @torch.no_grad()
@ -108,53 +138,84 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers"""

    def __init__(
        self,
        n_embed,
        n_layer,
        vocab_size=30522,
        max_seq_len=77,
        device='cuda',
        use_tokenizer=True,
        embedding_dropout=0.0,
    ):
        """Create the optional tokenizer and the transformer wrapper.

        When ``use_tokenizer`` is true a ``BERTTokenizer`` limited to
        ``max_seq_len`` is attached as ``self.tknz_fn``; otherwise
        ``forward`` expects already-tokenized input.
        """
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(
                vq_interface=False, max_length=max_seq_len
            )
        self.device = device
        encoder_layers = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=encoder_layers,
            emb_dropout=embedding_dropout,
        )

    def forward(self, text, embedding_manager=None):
        """Tokenize ``text`` if configured to, then return its embeddings."""
        tokens = self.tknz_fn(text) if self.use_tknz_fn else text
        return self.transformer(
            tokens, return_embeddings=True, embedding_manager=embedding_manager
        )

    def encode(self, text, **kwargs):
        """Alias for calling the module; keyword arguments are passed through."""
        # output of length 77
        return self(text, **kwargs)
class SpatialRescaler(nn.Module):
    """Repeatedly rescale a tensor and optionally remap its channels.

    ``forward`` applies ``torch.nn.functional.interpolate`` with
    ``scale_factor=multiplier`` ``n_stages`` times. If ``out_channels`` is
    given, a 1x1 ``nn.Conv2d`` then maps ``in_channels`` to ``out_channels``.
    """

    def __init__(
        self,
        n_stages=1,
        method='bilinear',
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
    ):
        super().__init__()
        supported_methods = (
            'nearest',
            'linear',
            'bilinear',
            'trilinear',
            'bicubic',
            'area',
        )
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in supported_methods
        self.multiplier = multiplier
        self.interpolator = partial(
            torch.nn.functional.interpolate, mode=method
        )
        self.remap_output = out_channels is not None
        if not self.remap_output:
            return
        print(
            f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
        )
        self.channel_mapper = nn.Conv2d(
            in_channels, out_channels, 1, bias=bias
        )

    def forward(self, x):
        """Rescale ``x`` ``n_stages`` times, then remap channels if enabled."""
        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        return self.channel_mapper(x) if self.remap_output else x

    def encode(self, x):
        """Alias for calling the module."""
        return self(x)
class FrozenCLIPEmbedder(AbstractEncoder): class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)""" """Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
def __init__(
self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77,
):
super().__init__() super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version,local_files_only=True) self.tokenizer = CLIPTokenizer.from_pretrained(
self.transformer = CLIPTextModel.from_pretrained(version,local_files_only=True) version, local_files_only=True
)
self.transformer = CLIPTextModel.from_pretrained(
version, local_files_only=True
)
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length
self.freeze() self.freeze()
@ -180,7 +252,11 @@ class FrozenCLIPEmbedder(AbstractEncoder):
embedding_manager=None, embedding_manager=None,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] seq_length = (
input_ids.shape[-1]
if input_ids is not None
else inputs_embeds.shape[-2]
)
if position_ids is None: if position_ids is None:
position_ids = self.position_ids[:, :seq_length] position_ids = self.position_ids[:, :seq_length]
@ -191,13 +267,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
if embedding_manager is not None: if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds) inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids) position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings embeddings = inputs_embeds + position_embeddings
return embeddings return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings) self.transformer.text_model.embeddings.forward = (
embedding_forward.__get__(self.transformer.text_model.embeddings)
)
def encoder_forward( def encoder_forward(
self, self,
@ -208,11 +285,21 @@ class FrozenCLIPEmbedder(AbstractEncoder):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_hidden_states = ( output_attentions
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
@ -239,8 +326,9 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return hidden_states return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) self.transformer.text_model.encoder.forward = encoder_forward.__get__(
self.transformer.text_model.encoder
)
def text_encoder_forward( def text_encoder_forward(
self, self,
@ -252,31 +340,47 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return_dict=None, return_dict=None,
embedding_manager=None, embedding_manager=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_hidden_states = ( output_attentions
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None: if input_ids is None:
raise ValueError("You have to specify either input_ids") raise ValueError('You have to specify either input_ids')
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager) hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding_manager=embedding_manager,
)
bsz, seq_len = input_shape bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here. # CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( causal_attention_mask = _build_causal_attention_mask(
hidden_states.device bsz, seq_len, hidden_states.dtype
) ).to(hidden_states.device)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype) attention_mask = _expand_mask(
attention_mask, hidden_states.dtype
)
last_hidden_state = self.encoder( last_hidden_state = self.encoder(
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
@ -291,7 +395,9 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return last_hidden_state return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) self.transformer.text_model.forward = text_encoder_forward.__get__(
self.transformer.text_model
)
def transformer_forward( def transformer_forward(
self, self,
@ -310,11 +416,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
embedding_manager = embedding_manager embedding_manager=embedding_manager,
) )
self.transformer.forward = transformer_forward.__get__(self.transformer) self.transformer.forward = transformer_forward.__get__(
self.transformer
)
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
@ -322,9 +429,16 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False param.requires_grad = False
def forward(self, text, **kwargs):
    """Tokenize ``text`` and run ``self.transformer`` on the resulting ids.

    The tokenizer truncates and pads to ``self.max_length``; the ids are
    moved to ``self.device``. Extra keyword arguments are forwarded to
    ``self.transformer`` unchanged.
    """
    encoded = self.tokenizer(
        text,
        truncation=True,
        max_length=self.max_length,
        return_length=True,
        return_overflowing_tokens=False,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(self.device)
    return self.transformer(input_ids=input_ids, **kwargs)
@ -337,9 +451,17 @@ class FrozenCLIPTextEmbedder(nn.Module):
""" """
Uses the CLIP transformer encoder for text. Uses the CLIP transformer encoder for text.
""" """
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
def __init__(
self,
version='ViT-L/14',
device='cuda',
max_length=77,
n_repeat=1,
normalize=True,
):
super().__init__() super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu") self.model, _ = clip.load(version, jit=False, device='cpu')
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length
self.n_repeat = n_repeat self.n_repeat = n_repeat
@ -369,6 +491,7 @@ class FrozenClipImageEmbedder(nn.Module):
""" """
Uses the CLIP image encoder. Uses the CLIP image encoder.
""" """
def __init__( def __init__(
self, self,
model, model,
@ -381,15 +504,27 @@ class FrozenClipImageEmbedder(nn.Module):
self.antialias = antialias self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) self.register_buffer(
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False,
)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False,
)
def preprocess(self, x):
    """Resize ``x`` to 224x224 and normalize it with the CLIP statistics.

    The image is resized bicubically, shifted and scaled by ``(x + 1) / 2``,
    then normalized with the ``mean`` and ``std`` buffers.
    """
    resized = kornia.geometry.resize(
        x,
        (224, 224),
        interpolation='bicubic',
        align_corners=True,
        antialias=self.antialias,
    )
    # normalize to [0,1]
    # NOTE(review): this presumably expects inputs in [-1, 1] -- confirm.
    unit_range = (resized + 1.0) / 2.0
    # renormalize according to clip
    return kornia.enhance.normalize(unit_range, self.mean, self.std)
@ -399,7 +534,8 @@ class FrozenClipImageEmbedder(nn.Module):
return self.model.encode_image(self.preprocess(x)) return self.model.encode_image(self.preprocess(x))
if __name__ == "__main__": if __name__ == '__main__':
from ldm.util import count_params from ldm.util import count_params
model = FrozenCLIPEmbedder() model = FrozenCLIPEmbedder()
count_params(model, verbose=True) count_params(model, verbose=True)

View File

@ -1,2 +1,6 @@
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr from ldm.modules.image_degradation.bsrgan import (
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
degradation_bsrgan_variant as degradation_fn_bsr_light,
)

View File

@ -27,13 +27,13 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
    """
    Crop a copy of the image so its first two dimensions are multiples of sf.

    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    """
    dim0, dim1 = img.shape[:2]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    duplicate = np.copy(img)
    return duplicate[:keep0, :keep1, ...]
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one # Loop over the small kernel to fill the big one
for r in range(k_size): for r in range(k_size):
for c in range(k_size): for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR # Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2 crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop] cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel k : kernel
""" """
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]]) V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]]) D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@ -126,23 +133,31 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k): def blur(x, k):
''' """
x: image, NxcxHxW x: image, NxcxHxW
k: kernel, Nx1xhxw k: kernel, Nx1xhxw
''' """
n, c = x.shape[:2] n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1) k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3]) k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3]) x = x.view(n, c, x.shape[2], x.shape[3])
return x return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" " """ "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang # Kai Zhang
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta # Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2]) LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], Q = np.array(
[np.sin(theta), np.cos(theta)]]) [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image) # Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None] MU = MU[None, None, :, None]
# Create meshgrid for Gaussian # Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize] hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) [x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std) arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg) h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0 h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs): def fspecial(filter_type, *args, **kwargs):
''' """
python code from: python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
''' """
if filter_type == 'gaussian': if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs) return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian': if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3): def bicubic_degradation(x, sf=3):
''' """
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
sf: down-scale factor sf: down-scale factor
Return: Return:
bicubicly downsampled LR image bicubicly downsampled LR image
''' """
x = util.imresize_np(x, scale=1 / sf) x = util.imresize_np(x, scale=1 / sf)
return x return x
def srmd_degradation(x, k, sf=3): def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling """blur + bicubic downsampling
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271}, pages={3262--3271},
year={2018} year={2018}
} }
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
return x return x
def dpsr_degradation(x, k, sf=3): def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur """bicubic downsampling + blur
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681}, pages={1671--1681},
year={2019} year={2019}
} }
''' """
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x return x
def classical_degradation(x, k, sf=3): def classical_degradation(x, k, sf=3):
''' blur + downsampling """blur + downsampling
Args: Args:
x: HxWxC image, [0, 1]/[0, 255] x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double k: hxw, double
sf: down-scale factor sf: down-scale factor
Return: Return:
downsampled LR image downsampled LR image
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0 st = 0
@ -328,10 +350,19 @@ def add_blur(img, sf=4):
if random.random() < 0.5: if random.random() < 0.5:
l1 = wd2 * random.random() l1 = wd2 * random.random()
l2 = wd2 * random.random() l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) k = anisotropic_Gaussian(
ksize=2 * random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else: else:
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) k = fspecial(
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img return img
@ -344,7 +375,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1) sf1 = random.uniform(0.5 / sf, 1)
else: else:
sf1 = 1.0 sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -366,19 +401,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0) # img = np.clip(img, 0.0, 1.0)
# return img # return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2) noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand() rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise else: # add noise
L = noise_level2 / 255. L = noise_level2 / 255.0
D = np.diag(np.random.rand(3)) D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3)) U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U) conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -388,28 +430,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
rnum = random.random() rnum = random.random()
if rnum > 0.6: if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4: elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: else:
L = noise_level2 / 255. L = noise_level2 / 255.0
D = np.diag(np.random.rand(3)) D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3)) U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U) conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
def add_Poisson_noise(img): def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255. img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4] vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5: if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals img = np.random.poisson(img * vals).astype(np.float32) / vals
else: else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis] img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -418,7 +469,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img): def add_JPEG_noise(img):
quality_factor = random.randint(30, 95) quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1) img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img return img
@ -431,7 +484,11 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq return lq, hq
@ -462,8 +519,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
img = util.imresize_np(img, 1 / 2, True) img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
@ -472,7 +532,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -487,19 +550,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.75: if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -551,8 +625,11 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
image = util.imresize_np(image, 1 / 2, True) image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
@ -561,7 +638,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -576,19 +656,33 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.75: if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -609,12 +703,19 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise # add final JPEG compression noise
image = add_JPEG_noise(image) image = add_JPEG_noise(image)
image = util.single2uint(image) image = util.single2uint(image)
example = {"image":image} example = {'image': image}
return example return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): def degradation_bsrgan_plus(
img,
sf=4,
shuffle_prob=0.5,
use_sharp=True,
lq_patchsize=64,
isp_model=None,
):
""" """
This is an extended degradation model by combining This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN the degradation models of BSRGAN and Real-ESRGAN
@ -645,8 +746,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
else: else:
shuffle_order = list(range(13)) shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one # local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) shuffle_order[2:6] = random.sample(
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) shuffle_order[2:6], len(range(2, 6))
)
shuffle_order[9:13] = random.sample(
shuffle_order[9:13], len(range(9, 13))
)
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
@ -689,8 +794,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
print('check the shuffle!') print('check the shuffle!')
# resize to desired size # resize to desired size
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
# add final JPEG compression noise # add final JPEG compression noise
img = add_JPEG_noise(img) img = add_JPEG_noise(img)
@ -702,29 +810,37 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
if __name__ == '__main__': if __name__ == '__main__':
print("hey") print('hey')
img = util.imread_uint('utils/test.png', 3) img = util.imread_uint('utils/test.png', 3)
print(img) print(img)
img = util.uint2single(img) img = util.uint2single(img)
print(img) print(img)
img = img[:448, :448] img = img[:448, :448]
h = img.shape[0] // 4 h = img.shape[0] // 4
print("resizing to", h) print('resizing to', h)
sf = 4 sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf) deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20): for i in range(20):
print(i) print(i)
img_lq = deg_fn(img) img_lq = deg_fn(img)
print(img_lq) print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img)['image']
print(img_lq.shape) print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape) print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape) print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), lq_nearest = cv2.resize(
interpolation=0) util.single2uint(img_lq),
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0) interpolation=0,
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) )
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png') util.imsave(img_concat, str(i) + '.png')

View File

@ -27,13 +27,13 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf): def modcrop_np(img, sf):
''' """
Args: Args:
img: numpy image, WxH or WxHxC img: numpy image, WxH or WxHxC
sf: scale factor sf: scale factor
Return: Return:
cropped image cropped image
''' """
w, h = img.shape[:2] w, h = img.shape[:2]
im = np.copy(img) im = np.copy(img)
return im[: w - w % sf, : h - h % sf, ...] return im[: w - w % sf, : h - h % sf, ...]
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one # Loop over the small kernel to fill the big one
for r in range(k_size): for r in range(k_size):
for c in range(k_size): for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR # Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2 crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop] cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel k : kernel
""" """
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]]) V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]]) D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@ -126,23 +133,31 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k): def blur(x, k):
''' """
x: image, NxcxHxW x: image, NxcxHxW
k: kernel, Nx1xhxw k: kernel, Nx1xhxw
''' """
n, c = x.shape[:2] n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1) k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3]) k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3]) x = x.view(n, c, x.shape[2], x.shape[3])
return x return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" " """ "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang # Kai Zhang
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta # Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2]) LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], Q = np.array(
[np.sin(theta), np.cos(theta)]]) [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image) # Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None] MU = MU[None, None, :, None]
# Create meshgrid for Gaussian # Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize] hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) [x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std) arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg) h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0 h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs): def fspecial(filter_type, *args, **kwargs):
''' """
python code from: python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
''' """
if filter_type == 'gaussian': if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs) return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian': if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3): def bicubic_degradation(x, sf=3):
''' """
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
sf: down-scale factor sf: down-scale factor
Return: Return:
bicubicly downsampled LR image bicubicly downsampled LR image
''' """
x = util.imresize_np(x, scale=1 / sf) x = util.imresize_np(x, scale=1 / sf)
return x return x
def srmd_degradation(x, k, sf=3): def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling """blur + bicubic downsampling
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271}, pages={3262--3271},
year={2018} year={2018}
} }
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
return x return x
def dpsr_degradation(x, k, sf=3): def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur """bicubic downsampling + blur
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681}, pages={1671--1681},
year={2019} year={2019}
} }
''' """
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x return x
def classical_degradation(x, k, sf=3): def classical_degradation(x, k, sf=3):
''' blur + downsampling """blur + downsampling
Args: Args:
x: HxWxC image, [0, 1]/[0, 255] x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double k: hxw, double
sf: down-scale factor sf: down-scale factor
Return: Return:
downsampled LR image downsampled LR image
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0 st = 0
@ -332,10 +354,19 @@ def add_blur(img, sf=4):
if random.random() < 0.5: if random.random() < 0.5:
l1 = wd2 * random.random() l1 = wd2 * random.random()
l2 = wd2 * random.random() l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) k = anisotropic_Gaussian(
ksize=random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else: else:
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) k = fspecial(
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 'gaussian', random.randint(2, 4) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img return img
@ -348,7 +379,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1) sf1 = random.uniform(0.5 / sf, 1)
else: else:
sf1 = 1.0 sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -370,19 +405,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0) # img = np.clip(img, 0.0, 1.0)
# return img # return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add zero-mean Gaussian noise to an image and clip the result to [0, 1].

    One of three noise models is picked at random:
      * per-element noise with the same shape as ``img`` (probability ~0.4),
      * a single noise plane of shape (H, W, 1) broadcast over the last
        axis (probability ~0.4),
      * noise correlated across three channels, drawn from a random
        3x3 covariance (remaining ~0.2).

    Args:
        img: image array; the first two axes are used as the spatial size.
            The correlated branch draws 3-vectors, so it presumably expects
            exactly three channels -- confirm against callers.
        noise_level1: lower bound (inclusive) of the noise level, on a
            0-255 scale.
        noise_level2: upper bound (inclusive) of the noise level, on a
            0-255 scale.

    Returns:
        A new array holding the noisy image, clipped to [0.0, 1.0].
    """
    # The standard deviation is sampled on the 0-255 scale and rescaled.
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    selector = np.random.rand()

    if selector > 0.6:
        # Independent noise for every element.
        noise = np.random.normal(0, sigma, img.shape)
    elif selector < 0.4:
        # One noise plane shared by all channels through broadcasting.
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:
        # Channel-correlated noise. Note that this branch scales by the
        # upper bound ``noise_level2`` rather than by the sampled level.
        scale = noise_level2 / 255.0
        eigenvalues = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        covariance = np.dot(np.dot(np.transpose(basis), eigenvalues), basis)
        noise = np.random.multivariate_normal(
            [0, 0, 0], np.abs(scale**2 * covariance), img.shape[:2]
        )

    return np.clip(img + noise.astype(np.float32), 0.0, 1.0)
@ -392,28 +434,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
rnum = random.random() rnum = random.random()
if rnum > 0.6: if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4: elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: else:
L = noise_level2 / 255. L = noise_level2 / 255.0
D = np.diag(np.random.rand(3)) D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3)) U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U) conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
def add_Poisson_noise(img):
    """Add Poisson (shot) noise to an image and clip the result to [0, 1].

    The image is first quantised to 256 levels. With probability 0.5 the
    noise is sampled independently for every element; otherwise it is
    sampled once from a luma plane built from the first three channels and
    added to every channel.

    Args:
        img: image array with values nominally in [0, 1]. The luma branch
            reads ``img[..., :3]``, so it presumably expects a channel-last
            layout -- confirm against callers.

    Returns:
        A new array holding the noisy image, clipped to [0.0, 1.0].
    """
    quantized = np.clip((img * 255.0).round(), 0, 255) / 255.0
    # The exponent is uniform in [2, 4], so the photon scale is in [1e2, 1e4].
    peak = 10 ** (2 * random.random() + 2.0)

    if random.random() < 0.5:
        # Independent Poisson sampling for every element.
        noisy = np.random.poisson(quantized * peak).astype(np.float32) / peak
        return np.clip(noisy, 0.0, 1.0)

    # Luma-only noise: sample on the gray plane, then add it to all channels.
    luma = np.dot(quantized[..., :3], [0.299, 0.587, 0.114])
    luma = np.clip((luma * 255.0).round(), 0, 255) / 255.0
    luma_noise = (
        np.random.poisson(luma * peak).astype(np.float32) / peak - luma
    )
    # In-place add keeps the dtype of ``quantized``.
    quantized += luma_noise[:, :, np.newaxis]
    return np.clip(quantized, 0.0, 1.0)
@ -422,7 +473,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
    """Degrade an image by a JPEG encode/decode round trip.

    The JPEG quality is drawn uniformly from the integers 80..95
    (inclusive).

    Args:
        img: RGB image; converted with ``util.single2uint`` before encoding
            (presumably float in [0, 1] -- confirm against ``util``).

    Returns:
        The decoded image, converted back with ``util.uint2single`` and
        returned in RGB channel order.
    """
    quality_factor = random.randint(80, 95)
    # OpenCV encodes and decodes in BGR order, so swap channels both ways.
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode(
        '.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
    )
    decoded = cv2.imdecode(encoded, 1)  # flag 1: decode as 3-channel colour
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
@ -435,7 +488,11 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq return lq, hq
@ -466,8 +523,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
img = util.imresize_np(img, 1 / 2, True) img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
@ -476,7 +536,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -491,19 +554,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.75: if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -555,8 +629,11 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
image = util.imresize_np(image, 1 / 2, True) image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
@ -565,7 +642,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -583,20 +663,34 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.8: if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -617,34 +711,41 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise # add final JPEG compression noise
image = add_JPEG_noise(image) image = add_JPEG_noise(image)
image = util.single2uint(image) image = util.single2uint(image)
example = {"image": image} example = {'image': image}
return example return example
if __name__ == '__main__':
    # Manual smoke test: degrade a crop of a sample image 20 times and save
    # each result side by side with a bicubic baseline and the original.
    print('hey')
    img = util.imread_uint('utils/test.png', 3)[:448, :448]
    h = img.shape[0] // 4
    print('resizing to', h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)

    for i in range(20):
        print(i)
        degraded = deg_fn(img)['image']
        img_hq = util.uint2single(img)
        img_lq = util.uint2single(degraded)
        print(img_lq)

        # Bicubic baseline: shrink so the smaller side equals ``h``.
        bicubic_resize = albumentations.SmallestMaxSize(
            max_size=h, interpolation=cv2.INTER_CUBIC
        )
        img_lq_bicubic = bicubic_resize(image=img_hq)['image']

        print(img_lq.shape)
        print('bicubic', img_lq_bicubic.shape)
        print(img_hq.shape)

        # Both low-quality images are blown up by ``sf`` to the same
        # (width, height), derived from the degraded image, using
        # interpolation flag 0 (nearest neighbour).
        upscaled_size = (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0]))
        lq_nearest = cv2.resize(
            util.single2uint(img_lq), upscaled_size, interpolation=0
        )
        lq_bicubic_nearest = cv2.resize(
            util.single2uint(img_lq_bicubic), upscaled_size, interpolation=0
        )

        img_concat = np.concatenate(
            [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
        )
        util.imsave(img_concat, str(i) + '.png')

View File

@ -6,13 +6,14 @@ import torch
import cv2 import cv2
from torchvision.utils import make_grid from torchvision.utils import make_grid
from datetime import datetime from datetime import datetime
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py # import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
''' """
# -------------------------------------------- # --------------------------------------------
# Kai Zhang (github: https://github.com/cszn) # Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019 # 03/Mar/2019
@ -20,10 +21,22 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# https://github.com/twhui/SRGAN-pyTorch # https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR # https://github.com/xinntao/BasicSR
# -------------------------------------------- # --------------------------------------------
''' """
# File-name suffixes recognised as images. Each extension is listed in both
# lower and upper case (presumably because the lookup is case-sensitive --
# confirm against is_image_file). '.TIF' was the one missing upper-case
# spelling and is appended at the end so existing positions are unchanged.
IMG_EXTENSIONS = [
    '.jpg',
    '.JPG',
    '.jpeg',
    '.JPEG',
    '.png',
    '.PNG',
    '.ppm',
    '.PPM',
    '.bmp',
    '.BMP',
    '.tif',
    '.TIF',
]
def is_image_file(filename): def is_image_file(filename):
@ -57,11 +70,11 @@ def surf(Z, cmap='rainbow', figsize=None):
plt.show() plt.show()
''' """
# -------------------------------------------- # --------------------------------------------
# get image pathes # get image pathes
# -------------------------------------------- # --------------------------------------------
''' """
def get_image_paths(dataroot): def get_image_paths(dataroot):
@ -83,11 +96,11 @@ def _get_paths_from_images(path):
return images return images
''' """
# -------------------------------------------- # --------------------------------------------
# split large images into small images # split large images into small images
# -------------------------------------------- # --------------------------------------------
''' """
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
@ -118,11 +131,21 @@ def imssave(imgs, img_path):
for i, img in enumerate(imgs): for i, img in enumerate(imgs):
if img.ndim == 3: if img.ndim == 3:
img = img[:, :, [2, 1, 0]] img = img[:, :, [2, 1, 0]]
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') new_path = os.path.join(
os.path.dirname(img_path),
img_name + str('_s{:04d}'.format(i)) + '.png',
)
cv2.imwrite(new_path, img) cv2.imwrite(new_path, img)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): def split_imageset(
original_dataroot,
taget_dataroot,
n_channels=3,
p_size=800,
p_overlap=96,
p_max=1000,
):
""" """
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
@ -139,15 +162,18 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800,
# img_name, ext = os.path.splitext(os.path.basename(img_path)) # img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels) img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max) patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) imssave(
patches, os.path.join(taget_dataroot, os.path.basename(img_path))
)
# if original_dataroot == taget_dataroot: # if original_dataroot == taget_dataroot:
# del img_path # del img_path
'''
"""
# -------------------------------------------- # --------------------------------------------
# makedir # makedir
# -------------------------------------------- # --------------------------------------------
''' """
def mkdir(path): def mkdir(path):
@ -171,12 +197,12 @@ def mkdir_and_rename(path):
os.makedirs(path) os.makedirs(path)
''' """
# -------------------------------------------- # --------------------------------------------
# read image from path # read image from path
# opencv is fast, but read BGR numpy image # opencv is fast, but read BGR numpy image
# -------------------------------------------- # --------------------------------------------
''' """
# -------------------------------------------- # --------------------------------------------
@ -206,6 +232,7 @@ def imsave(img, img_path):
img = img[:, :, [2, 1, 0]] img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img) cv2.imwrite(img_path, img)
def imwrite(img, img_path): def imwrite(img, img_path):
img = np.squeeze(img) img = np.squeeze(img)
if img.ndim == 3: if img.ndim == 3:
@ -213,7 +240,6 @@ def imwrite(img, img_path):
cv2.imwrite(img_path, img) cv2.imwrite(img_path, img)
# -------------------------------------------- # --------------------------------------------
# get single image of size HxWxn_channels (BGR) # get single image of size HxWxn_channels (BGR)
# -------------------------------------------- # --------------------------------------------
@ -221,7 +247,7 @@ def read_img(path):
# read image by cv2 # read image by cv2
# return: Numpy float32, HWC, BGR, [0,1] # return: Numpy float32, HWC, BGR, [0,1]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
img = img.astype(np.float32) / 255. img = img.astype(np.float32) / 255.0
if img.ndim == 2: if img.ndim == 2:
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)
# some images have 4 channels # some images have 4 channels
@ -230,7 +256,7 @@ def read_img(path):
return img return img
''' """
# -------------------------------------------- # --------------------------------------------
# image format conversion # image format conversion
# -------------------------------------------- # --------------------------------------------
@ -238,7 +264,7 @@ def read_img(path):
# numpy(single) <---> tensor # numpy(single) <---> tensor
# numpy(unit) <---> tensor # numpy(unit) <---> tensor
# -------------------------------------------- # --------------------------------------------
''' """
# -------------------------------------------- # --------------------------------------------
@ -248,22 +274,22 @@ def read_img(path):
def uint2single(img):
    """Scale 8-bit values from [0, 255] to float32 in [0, 1]."""
    scaled = img / 255.0
    return np.float32(scaled)
def single2uint(img):
    """Clip to [0, 1], scale to [0, 255], round, and convert to uint8."""
    clipped = img.clip(0, 1)
    return np.uint8((clipped * 255.0).round())
def uint162single(img):
    """Scale 16-bit values from [0, 65535] to float32 in [0, 1]."""
    scaled = img / 65535.0
    return np.float32(scaled)
def single2uint16(img):
    """Clip to [0, 1], scale to [0, 65535], round, and convert to uint16."""
    clipped = img.clip(0, 1)
    return np.uint16((clipped * 65535.0).round())
# -------------------------------------------- # --------------------------------------------
@ -275,14 +301,25 @@ def single2uint16(img):
def uint2tensor4(img):
    """Convert a 2-D or 3-D array to a 4-D float tensor scaled by 1/255.

    A 2-D input gets a trailing singleton axis first. The last axis is then
    moved to the front and a leading axis of size 1 is added, so an
    (H, W, C) array becomes a (1, C, H, W) tensor.
    """
    if img.ndim == 2:
        img = img[..., np.newaxis]
    channels_first = torch.from_numpy(np.ascontiguousarray(img)).permute(
        2, 0, 1
    )
    return channels_first.float().div(255.0).unsqueeze(0)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    """Convert a 2-D or 3-D array to a 3-D float tensor scaled by 1/255.

    A 2-D input gets a trailing singleton axis first. The last axis is then
    moved to the front, so an (H, W, C) array becomes a (C, H, W) tensor.
    """
    if img.ndim == 2:
        img = img[..., np.newaxis]
    channels_first = torch.from_numpy(np.ascontiguousarray(img)).permute(
        2, 0, 1
    )
    return channels_first.float().div(255.0)
# convert 2/3/4-dimensional torch tensor to uint # convert 2/3/4-dimensional torch tensor to uint
@ -305,7 +342,12 @@ def single2tensor3(img):
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    """Convert an (H, W, C) array to a (1, C, H, W) float tensor.

    Values are not rescaled; only the axis order and dtype change.
    """
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    channels_first = as_tensor.permute(2, 0, 1).float()
    return channels_first.unsqueeze(0)
# convert torch tensor to single # convert torch tensor to single
@ -316,6 +358,7 @@ def tensor2single(img):
return img return img
# convert torch tensor to single # convert torch tensor to single
def tensor2single3(img): def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy() img = img.data.squeeze().float().cpu().numpy()
@ -327,30 +370,48 @@ def tensor2single3(img):
def single2tensor5(img):
    """Convert a 4-D array to a 5-D float tensor.

    Axes are reordered with ``permute(2, 0, 1, 3)`` and a leading axis of
    size 1 is added, so a (d0, d1, d2, d3) array becomes
    (1, d2, d0, d1, d3). Values are not rescaled.
    """
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    reordered = as_tensor.permute(2, 0, 1, 3).float()
    return reordered.unsqueeze(0)
def single32tensor5(img):
    """Convert an array to a float tensor with two leading size-1 axes.

    No axes are reordered and values are not rescaled; a 3-D input of shape
    (d0, d1, d2) becomes (1, 1, d0, d1, d2).
    """
    as_float = torch.from_numpy(np.ascontiguousarray(img)).float()
    return as_float.unsqueeze(0).unsqueeze(0)
def single42tensor4(img):
    """Convert a 4-D array to a 4-D float tensor with reordered axes.

    Axes are reordered with ``permute(2, 0, 1, 3)``, so a
    (d0, d1, d2, d3) array becomes (d2, d0, d1, d3). Values are not
    rescaled.
    """
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    return as_tensor.permute(2, 0, 1, 3).float()
# from skimage.io import imread, imsave # from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
''' """
Converts a torch Tensor into an image Numpy array of BGR channel order Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
''' """
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp tensor = (
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] tensor.squeeze().float().cpu().clamp_(*min_max)
) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (
min_max[1] - min_max[0]
) # to range [0,1]
n_dim = tensor.dim() n_dim = tensor.dim()
if n_dim == 4: if n_dim == 4:
n_img = len(tensor) n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() img_np = make_grid(
tensor, nrow=int(math.sqrt(n_img)), normalize=False
).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3: elif n_dim == 3:
img_np = tensor.numpy() img_np = tensor.numpy()
@ -359,14 +420,17 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
img_np = tensor.numpy() img_np = tensor.numpy()
else: else:
raise TypeError( raise TypeError(
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(
n_dim
)
)
if out_type == np.uint8: if out_type == np.uint8:
img_np = (img_np * 255.0).round() img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default. # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
return img_np.astype(out_type) return img_np.astype(out_type)
''' """
# -------------------------------------------- # --------------------------------------------
# Augmentation, flipe and/or rotate # Augmentation, flipe and/or rotate
# -------------------------------------------- # --------------------------------------------
@ -374,12 +438,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
# (1) augmet_img: numpy image of WxHxC or WxH # (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH # (2) augment_img_tensor4: tensor image 1xCxWxH
# -------------------------------------------- # --------------------------------------------
''' """
def augment_img(img, mode=0): def augment_img(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn) """Kai Zhang (github: https://github.com/cszn)"""
'''
if mode == 0: if mode == 0:
return img return img
elif mode == 1: elif mode == 1:
@ -399,8 +462,7 @@ def augment_img(img, mode=0):
def augment_img_tensor4(img, mode=0): def augment_img_tensor4(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn) """Kai Zhang (github: https://github.com/cszn)"""
'''
if mode == 0: if mode == 0:
return img return img
elif mode == 1: elif mode == 1:
@ -420,8 +482,7 @@ def augment_img_tensor4(img, mode=0):
def augment_img_tensor(img, mode=0): def augment_img_tensor(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn) """Kai Zhang (github: https://github.com/cszn)"""
'''
img_size = img.size() img_size = img.size()
img_np = img.data.cpu().numpy() img_np = img.data.cpu().numpy()
if len(img_size) == 3: if len(img_size) == 3:
@ -484,11 +545,11 @@ def augment_imgs(img_list, hflip=True, rot=True):
return [_augment(img) for img in img_list] return [_augment(img) for img in img_list]
''' """
# -------------------------------------------- # --------------------------------------------
# modcrop and shave # modcrop and shave
# -------------------------------------------- # --------------------------------------------
''' """
def modcrop(img_in, scale): def modcrop(img_in, scale):
@ -515,7 +576,7 @@ def shave(img_in, border=0):
return img return img
''' """
# -------------------------------------------- # --------------------------------------------
# image processing process on numpy image # image processing process on numpy image
# channel_convert(in_c, tar_type, img_list): # channel_convert(in_c, tar_type, img_list):
@ -523,74 +584,92 @@ def shave(img_in, border=0):
# bgr2ycbcr(img, only_y=True): # bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img): # ycbcr2rgb(img):
# -------------------------------------------- # --------------------------------------------
''' """
def rgb2ycbcr(img, only_y=True):
    """Convert an RGB image to YCbCr (same as matlab rgb2ycbcr).

    The input array is never modified.

    Args:
        img: array with the R, G, B channels on the last axis, either
            uint8 in [0, 255] or float in [0, 1].
        only_y: if True, return only the Y channel (the last axis is
            dropped); otherwise return all three Y, Cb, Cr channels.

    Returns:
        The converted image with the same dtype as ``img``. uint8 input
        yields rounded values on the 0-255 scale; float input yields values
        on the 0-1 scale.
    """
    in_img_type = img.dtype
    # ``astype`` returns a new array, so its result must be kept. Working on
    # this private float32 copy also means the in-place rescale below cannot
    # touch the caller's image.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [65.481, -37.797, 112.0],
                [128.553, -74.203, -93.786],
                [24.966, 112.0, -18.214],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB (same as matlab ycbcr2rgb).

    The input array is never modified.

    Args:
        img: array with the Y, Cb, Cr channels on the last axis, either
            uint8 in [0, 255] or float in [0, 1].

    Returns:
        The RGB image with the same dtype as ``img``. uint8 results are
        rounded and saturated to [0, 255]; float results are on the 0-1
        scale and are not clipped.
    """
    in_img_type = img.dtype
    # ``astype`` returns a new array, so its result must be kept. Working on
    # this private float32 copy also means the in-place rescale below cannot
    # touch the caller's image.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    rlt = np.matmul(
        img,
        [
            [0.00456621, 0.00456621, 0.00456621],
            [0, -0.00153632, 0.00791071],
            [0.00625893, -0.00318811, 0],
        ],
    ) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        # Out-of-gamut inputs land outside [0, 255]; saturate before the
        # integer cast so they do not wrap around.
        rlt = np.clip(rlt.round(), 0, 255)
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr (bgr version of rgb2ycbcr).

    The input array is never modified.

    Args:
        img: array with the B, G, R channels on the last axis, either
            uint8 in [0, 255] or float in [0, 1].
        only_y: if True, return only the Y channel (the last axis is
            dropped); otherwise return all three Y, Cb, Cr channels.

    Returns:
        The converted image with the same dtype as ``img``. uint8 input
        yields rounded values on the 0-255 scale; float input yields values
        on the 0-1 scale.
    """
    in_img_type = img.dtype
    # ``astype`` returns a new array, so its result must be kept. Working on
    # this private float32 copy also means the in-place rescale below cannot
    # touch the caller's image.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [24.966, 112.0, -18.214],
                [128.553, -74.203, -93.786],
                [65.481, -37.797, 112.0],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)
@ -608,11 +687,11 @@ def channel_convert(in_c, tar_type, img_list):
return img_list return img_list
''' """
# -------------------------------------------- # --------------------------------------------
# metric, PSNR and SSIM # metric, PSNR and SSIM
# -------------------------------------------- # --------------------------------------------
''' """
# -------------------------------------------- # --------------------------------------------
@ -640,10 +719,10 @@ def calculate_psnr(img1, img2, border=0):
# SSIM # SSIM
# -------------------------------------------- # --------------------------------------------
def calculate_ssim(img1, img2, border=0): def calculate_ssim(img1, img2, border=0):
'''calculate SSIM """calculate SSIM
the same outputs as MATLAB's the same outputs as MATLAB's
img1, img2: [0, 255] img1, img2: [0, 255]
''' """
# img1 = img1.squeeze() # img1 = img1.squeeze()
# img2 = img2.squeeze() # img2 = img2.squeeze()
if not img1.shape == img2.shape: if not img1.shape == img2.shape:
@ -684,16 +763,17 @@ def ssim(img1, img2):
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(sigma1_sq + sigma2_sq + C2)) (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
return ssim_map.mean() return ssim_map.mean()
''' """
# -------------------------------------------- # --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1] # matlab's bicubic imresize (numpy and torch) [0, 1]
# -------------------------------------------- # --------------------------------------------
''' """
# matlab 'imresize' function, now only support 'bicubic' # matlab 'imresize' function, now only support 'bicubic'
@ -701,11 +781,14 @@ def cubic(x):
absx = torch.abs(x) absx = torch.abs(x)
absx2 = absx**2 absx2 = absx**2
absx3 = absx**3 absx3 = absx**3
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): def calculate_weights_indices(
in_length, out_length, scale, kernel, kernel_width, antialiasing
):
if (scale < 1) and (antialiasing): if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale kernel_width = kernel_width / scale
@ -729,8 +812,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width
# The indices of the input pixels involved in computing the k-th output # The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix. # pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
1, P).expand(out_length, P) 0, P - 1, P
).view(1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the # The weights used to compute the k-th output pixel are in row k of the
# weights matrix. # weights matrix.
@ -771,7 +855,11 @@ def imresize(img, scale, antialiasing=True):
if need_squeeze: if need_squeeze:
img.unsqueeze_(0) img.unsqueeze_(0)
in_C, in_H, in_W = img.size() in_C, in_H, in_W = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4 kernel_width = 4
kernel = 'cubic' kernel = 'cubic'
@ -782,9 +870,11 @@ def imresize(img, scale, antialiasing=True):
# get weights and indices # get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing) in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing) in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension # process H dimension
# symmetric copying # symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
@ -805,7 +895,11 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_H): for i in range(out_H):
idx = int(indices_H[i][0]) idx = int(indices_H[i][0])
for j in range(out_C): for j in range(out_C):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) out_1[j, i, :] = (
img_aug[j, idx : idx + kernel_width, :]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension # process W dimension
# symmetric copying # symmetric copying
@ -827,7 +921,9 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_W): for i in range(out_W):
idx = int(indices_W[i][0]) idx = int(indices_W[i][0])
for j in range(out_C): for j in range(out_C):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(
weights_W[i]
)
if need_squeeze: if need_squeeze:
out_2.squeeze_() out_2.squeeze_()
return out_2 return out_2
@ -846,7 +942,11 @@ def imresize_np(img, scale, antialiasing=True):
img.unsqueeze_(2) img.unsqueeze_(2)
in_H, in_W, in_C = img.size() in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4 kernel_width = 4
kernel = 'cubic' kernel = 'cubic'
@ -857,9 +957,11 @@ def imresize_np(img, scale, antialiasing=True):
# get weights and indices # get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing) in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing) in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension # process H dimension
# symmetric copying # symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
@ -880,7 +982,11 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_H): for i in range(out_H):
idx = int(indices_H[i][0]) idx = int(indices_H[i][0])
for j in range(out_C): for j in range(out_C):
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) out_1[i, :, j] = (
img_aug[idx : idx + kernel_width, :, j]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension # process W dimension
# symmetric copying # symmetric copying
@ -902,7 +1008,9 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_W): for i in range(out_W):
idx = int(indices_W[i][0]) idx = int(indices_W[i][0])
for j in range(out_C): for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(
weights_W[i]
)
if need_squeeze: if need_squeeze:
out_2.squeeze_() out_2.squeeze_()

View File

@ -5,13 +5,24 @@ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/
class LPIPSWithDiscriminator(nn.Module): class LPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, def __init__(
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, self,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_start,
disc_loss="hinge"): logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss='hinge',
):
super().__init__() super().__init__()
assert disc_loss in ["hinge", "vanilla"] assert disc_loss in ['hinge', 'vanilla']
self.kl_weight = kl_weight self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval() self.perceptual_loss = LPIPS().eval()
@ -19,42 +30,68 @@ class LPIPSWithDiscriminator(nn.Module):
# output log variance # output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers, n_layers=disc_num_layers,
use_actnorm=use_actnorm use_actnorm=use_actnorm,
).apply(weights_init) ).apply(weights_init)
self.discriminator_iter_start = disc_start self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss self.disc_loss = (
hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss
)
self.disc_factor = disc_factor self.disc_factor = disc_factor
self.discriminator_weight = disc_weight self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None: if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else: else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight d_weight = d_weight * self.discriminator_weight
return d_weight return d_weight
def forward(self, inputs, reconstructions, posteriors, optimizer_idx, def forward(
global_step, last_layer=None, cond=None, split="train", self,
weights=None): inputs,
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
weights=None,
):
rec_loss = torch.abs(
inputs.contiguous() - reconstructions.contiguous()
)
if self.perceptual_weight > 0: if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss weighted_nll_loss = nll_loss
if weights is not None: if weights is not None:
weighted_nll_loss = weights * nll_loss weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] weighted_nll_loss = (
torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
)
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl() kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
@ -67,27 +104,42 @@ class LPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous()) logits_fake = self.discriminator(reconstructions.contiguous())
else: else:
assert self.disc_conditional assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake) g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0: if self.disc_factor > 0.0:
try: try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError: except RuntimeError:
assert not self.training assert not self.training
d_weight = torch.tensor(0.0) d_weight = torch.tensor(0.0)
else: else:
d_weight = torch.tensor(0.0) d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), log = {
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), '{}/total_loss'.format(split): loss.clone().detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(), '{}/logvar'.format(split): self.logvar.detach(),
"{}/d_weight".format(split): d_weight.detach(), '{}/kl_loss'.format(split): kl_loss.detach().mean(),
"{}/disc_factor".format(split): torch.tensor(disc_factor), '{}/nll_loss'.format(split): nll_loss.detach().mean(),
"{}/g_loss".format(split): g_loss.detach().mean(), '{}/rec_loss'.format(split): rec_loss.detach().mean(),
'{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
} }
return loss, log return loss, log
@ -95,17 +147,29 @@ class LPIPSWithDiscriminator(nn.Module):
# second pass for discriminator update # second pass for discriminator update
if cond is None: if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach()) logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else: else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) logits_real = self.discriminator(
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), log = {
"{}/logits_real".format(split): logits_real.detach().mean(), '{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean() '{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
} }
return d_loss, log return d_loss, log

View File

@ -3,21 +3,25 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import repeat from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init from taming.modules.discriminator.model import (
NLayerDiscriminator,
weights_init,
)
from taming.modules.losses.lpips import LPIPS from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum() loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake) d_loss = 0.5 * (loss_real + loss_fake)
return d_loss return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold: if global_step < threshold:
weight = value weight = value
return weight return weight
@ -26,12 +30,15 @@ def adopt_weight(weight, global_step, threshold=0, value=0.):
def measure_perplexity(predicted_indices, n_embed): def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) encodings = (
F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
)
avg_probs = encodings.mean(0) avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0) cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use return perplexity, cluster_use
def l1(x, y): def l1(x, y):
return torch.abs(x - y) return torch.abs(x - y)
@ -41,42 +48,58 @@ def l2(x, y):
class VQLPIPSWithDiscriminator(nn.Module): class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, def __init__(
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, self,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_start,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", codebook_weight=1.0,
pixel_loss="l1"): pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss='hinge',
n_classes=None,
perceptual_loss='lpips',
pixel_loss='l1',
):
super().__init__() super().__init__()
assert disc_loss in ["hinge", "vanilla"] assert disc_loss in ['hinge', 'vanilla']
assert perceptual_loss in ["lpips", "clips", "dists"] assert perceptual_loss in ['lpips', 'clips', 'dists']
assert pixel_loss in ["l1", "l2"] assert pixel_loss in ['l1', 'l2']
self.codebook_weight = codebook_weight self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips": if perceptual_loss == 'lpips':
print(f"{self.__class__.__name__}: Running with LPIPS.") print(f'{self.__class__.__name__}: Running with LPIPS.')
self.perceptual_loss = LPIPS().eval() self.perceptual_loss = LPIPS().eval()
else: else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") raise ValueError(
f'Unknown perceptual loss: >> {perceptual_loss} <<'
)
self.perceptual_weight = perceptual_weight self.perceptual_weight = perceptual_weight
if pixel_loss == "l1": if pixel_loss == 'l1':
self.pixel_loss = l1 self.pixel_loss = l1
else: else:
self.pixel_loss = l2 self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers, n_layers=disc_num_layers,
use_actnorm=use_actnorm, use_actnorm=use_actnorm,
ndf=disc_ndf ndf=disc_ndf,
).apply(weights_init) ).apply(weights_init)
self.discriminator_iter_start = disc_start self.discriminator_iter_start = disc_start
if disc_loss == "hinge": if disc_loss == 'hinge':
self.disc_loss = hinge_d_loss self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla": elif disc_loss == 'vanilla':
self.disc_loss = vanilla_d_loss self.disc_loss = vanilla_d_loss
else: else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.") raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.')
self.disc_factor = disc_factor self.disc_factor = disc_factor
self.discriminator_weight = disc_weight self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional self.disc_conditional = disc_conditional
@ -84,25 +107,47 @@ class VQLPIPSWithDiscriminator(nn.Module):
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None: if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else: else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight d_weight = d_weight * self.discriminator_weight
return d_weight return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, def forward(
global_step, last_layer=None, cond=None, split="train", predicted_indices=None): self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
predicted_indices=None,
):
if not exists(codebook_loss): if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device) codebook_loss = torch.tensor([0.0]).to(inputs.device)
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) rec_loss = self.pixel_loss(
inputs.contiguous(), reconstructions.contiguous()
)
if self.perceptual_weight > 0: if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss rec_loss = rec_loss + self.perceptual_weight * p_loss
else: else:
p_loss = torch.tensor([0.0]) p_loss = torch.tensor([0.0])
@ -119,49 +164,77 @@ class VQLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous()) logits_fake = self.discriminator(reconstructions.contiguous())
else: else:
assert self.disc_conditional assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake) g_loss = -torch.mean(logits_fake)
try: try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError: except RuntimeError:
assert not self.training assert not self.training
d_weight = torch.tensor(0.0) d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), log = {
"{}/quant_loss".format(split): codebook_loss.detach().mean(), '{}/total_loss'.format(split): loss.clone().detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(), '{}/quant_loss'.format(split): codebook_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(), '{}/nll_loss'.format(split): nll_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(), '{}/rec_loss'.format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(), '{}/p_loss'.format(split): p_loss.detach().mean(),
"{}/disc_factor".format(split): torch.tensor(disc_factor), '{}/d_weight'.format(split): d_weight.detach(),
"{}/g_loss".format(split): g_loss.detach().mean(), '{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
} }
if predicted_indices is not None: if predicted_indices is not None:
assert self.n_classes is not None assert self.n_classes is not None
with torch.no_grad(): with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) perplexity, cluster_usage = measure_perplexity(
log[f"{split}/perplexity"] = perplexity predicted_indices, self.n_classes
log[f"{split}/cluster_usage"] = cluster_usage )
log[f'{split}/perplexity'] = perplexity
log[f'{split}/cluster_usage'] = cluster_usage
return loss, log return loss, log
if optimizer_idx == 1: if optimizer_idx == 1:
# second pass for discriminator update # second pass for discriminator update
if cond is None: if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach()) logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else: else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) logits_real = self.discriminator(
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), log = {
"{}/logits_real".format(split): logits_real.detach().mean(), '{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean() '{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
} }
return d_loss, log return d_loss, log

View File

@ -11,15 +11,13 @@ from einops import rearrange, repeat, reduce
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [ Intermediates = namedtuple(
'pre_softmax_attn', 'Intermediates', ['pre_softmax_attn', 'post_softmax_attn']
'post_softmax_attn' )
])
LayerIntermediates = namedtuple('Intermediates', [ LayerIntermediates = namedtuple(
'hiddens', 'Intermediates', ['hiddens', 'attn_intermediates']
'attn_intermediates' )
])
class AbsolutePositionalEmbedding(nn.Module): class AbsolutePositionalEmbedding(nn.Module):
@ -39,11 +37,16 @@ class AbsolutePositionalEmbedding(nn.Module):
class FixedPositionalEmbedding(nn.Module): class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq) self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0): def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset t = (
torch.arange(x.shape[seq_dim], device=x.device).type_as(
self.inv_freq
)
+ offset
)
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :] return emb[None, :, :]
@ -51,6 +54,7 @@ class FixedPositionalEmbedding(nn.Module):
# helpers # helpers
def exists(val): def exists(val):
return val is not None return val is not None
@ -64,18 +68,21 @@ def default(val, d):
def always(val): def always(val):
def inner(*args, **kwargs): def inner(*args, **kwargs):
return val return val
return inner return inner
def not_equals(val): def not_equals(val):
def inner(x): def inner(x):
return x != val return x != val
return inner return inner
def equals(val): def equals(val):
def inner(x): def inner(x):
return x == val return x == val
return inner return inner
@ -85,6 +92,7 @@ def max_neg_value(tensor):
# keyword argument helpers # keyword argument helpers
def pick_and_pop(keys, d): def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys)) values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values)) return dict(zip(keys, values))
@ -108,8 +116,15 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d): def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) kwargs_with_prefix, kwargs = group_dict_by_key(
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(
lambda x: (x[0][len(prefix) :], x[1]),
tuple(kwargs_with_prefix.items()),
)
)
return kwargs_without_prefix, kwargs return kwargs_without_prefix, kwargs
@ -173,7 +188,7 @@ class GRUGating(nn.Module):
def forward(self, x, residual): def forward(self, x, residual):
gated_output = self.gru( gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'), rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d') rearrange(residual, 'b n d -> (b n) d'),
) )
return gated_output.reshape_as(x) return gated_output.reshape_as(x)
@ -181,6 +196,7 @@ class GRUGating(nn.Module):
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super().__init__() super().__init__()
@ -192,19 +208,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = (
nn.Linear(dim, inner_dim), nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
nn.GELU() if not glu
) if not glu else GEGLU(dim, inner_dim) else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
) )
def forward(self, x): def forward(self, x):
@ -224,12 +239,14 @@ class Attention(nn.Module):
sparse_topk=None, sparse_topk=None,
use_entmax15=False, use_entmax15=False,
num_mem_kv=0, num_mem_kv=0,
dropout=0., dropout=0.0,
on_attn=False on_attn=False,
): ):
super().__init__() super().__init__()
if use_entmax15: if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!") raise NotImplementedError(
'Check out entmax activation instead of softmax activation!'
)
self.scale = dim_head**-0.5 self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.causal = causal self.causal = causal
@ -263,7 +280,11 @@ class Attention(nn.Module):
# attention on attention # attention on attention
self.attn_on_attn = on_attn self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(inner_dim, dim)
)
def forward( def forward(
self, self,
@ -274,9 +295,14 @@ class Attention(nn.Module):
rel_pos=None, rel_pos=None,
sinusoidal_emb=None, sinusoidal_emb=None,
prev_attn=None, prev_attn=None,
mem=None mem=None,
): ):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device b, n, _, h, talking_heads, device = (
*x.shape,
self.heads,
self.talking_heads,
x.device,
)
kv_input = default(context, x) kv_input = default(context, x)
q_input = x q_input = x
@ -297,23 +323,35 @@ class Attention(nn.Module):
k = self.to_k(k_input) k = self.to_k(k_input)
v = self.to_v(v_input) v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)
)
input_mask = None input_mask = None
if any(map(exists, (mask, context_mask))): if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) q_mask = default(
mask, lambda: torch.ones((b, n), device=device).bool()
)
k_mask = q_mask if not exists(context) else context_mask k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) k_mask = default(
k_mask,
lambda: torch.ones((b, k.shape[-2]), device=device).bool(),
)
q_mask = rearrange(q_mask, 'b i -> b () i ()') q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j') k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask input_mask = q_mask * k_mask
if self.num_mem_kv > 0: if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) mem_k, mem_v = map(
lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v),
)
k = torch.cat((mem_k, k), dim=-2) k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2) v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask): if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True
)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots) mask_value = max_neg_value(dots)
@ -324,7 +362,9 @@ class Attention(nn.Module):
pre_softmax_attn = dots pre_softmax_attn = dots
if talking_heads: if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() dots = einsum(
'b h i j, h k -> b k i j', dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos): if exists(rel_pos):
dots = rel_pos(dots) dots = rel_pos(dots)
@ -336,7 +376,9 @@ class Attention(nn.Module):
if self.causal: if self.causal:
i, j = dots.shape[-2:] i, j = dots.shape[-2:]
r = torch.arange(i, device=device) r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j'
)
mask = F.pad(mask, (j - i, 0), value=False) mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value) dots.masked_fill_(mask, mask_value)
del mask del mask
@ -354,14 +396,16 @@ class Attention(nn.Module):
attn = self.dropout(attn) attn = self.dropout(attn)
if talking_heads: if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() attn = einsum(
'b h i j, h k -> b k i j', attn, self.post_softmax_proj
).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v) out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates( intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn, pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn post_softmax_attn=post_softmax_attn,
) )
return self.to_out(out), intermediates return self.to_out(out), intermediates
@ -390,7 +434,7 @@ class AttentionLayers(nn.Module):
macaron=False, macaron=False,
pre_norm=True, pre_norm=True,
gate_residual=False, gate_residual=False,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
@ -403,10 +447,14 @@ class AttentionLayers(nn.Module):
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
self.rotary_pos_emb = always(None) self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' assert (
rel_pos_num_buckets <= rel_pos_max_distance
), 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None self.rel_pos = None
self.pre_norm = pre_norm self.pre_norm = pre_norm
@ -438,15 +486,27 @@ class AttentionLayers(nn.Module):
assert 1 < par_ratio <= par_depth, 'par ratio out of range' assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block)) default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper depth_cut = (
par_depth * 2 // 3
) # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio' assert (
par_block = default_block + ('f',) * (par_width - len(default_block)) len(default_block) <= par_width
), 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (
par_width - len(default_block)
)
par_head = par_block * par_attn par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head)) layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef): elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' assert (
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef sandwich_coef > 0 and sandwich_coef <= depth
), 'sandwich coefficient should be less than the depth'
layer_types = (
('a',) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ('f',) * sandwich_coef
)
else: else:
layer_types = default_block * depth layer_types = default_block * depth
@ -455,7 +515,9 @@ class AttentionLayers(nn.Module):
for layer_type in self.layer_types: for layer_type in self.layer_types:
if layer_type == 'a': if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs
)
elif layer_type == 'c': elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs) layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f': elif layer_type == 'f':
@ -472,11 +534,7 @@ class AttentionLayers(nn.Module):
else: else:
residual_fn = Residual() residual_fn = Residual()
self.layers.append(nn.ModuleList([ self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
norm_fn(),
layer,
residual_fn
]))
def forward( def forward(
self, self,
@ -486,7 +544,7 @@ class AttentionLayers(nn.Module):
context_mask=None, context_mask=None,
mems=None, mems=None,
return_hiddens=False, return_hiddens=False,
**kwargs **kwargs,
): ):
hiddens = [] hiddens = []
intermediates = [] intermediates = []
@ -495,7 +553,9 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1) is_last = ind == (len(self.layers) - 1)
if layer_type == 'a': if layer_type == 'a':
@ -508,10 +568,22 @@ class AttentionLayers(nn.Module):
x = norm(x) x = norm(x)
if layer_type == 'a': if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, out, inter = block(
prev_attn=prev_attn, mem=layer_mem) x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == 'c': elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == 'f': elif layer_type == 'f':
out = block(x) out = block(x)
@ -530,8 +602,7 @@ class AttentionLayers(nn.Module):
if return_hiddens: if return_hiddens:
intermediates = LayerIntermediates( intermediates = LayerIntermediates(
hiddens=hiddens, hiddens=hiddens, attn_intermediates=intermediates
attn_intermediates=intermediates
) )
return x, intermediates return x, intermediates
@ -545,7 +616,6 @@ class Encoder(AttentionLayers):
super().__init__(causal=False, **kwargs) super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module): class TransformerWrapper(nn.Module):
def __init__( def __init__(
self, self,
@ -554,14 +624,16 @@ class TransformerWrapper(nn.Module):
max_seq_len, max_seq_len,
attn_layers, attn_layers,
emb_dim=None, emb_dim=None,
max_mem_len=0., max_mem_len=0.0,
emb_dropout=0., emb_dropout=0.0,
num_memory_tokens=None, num_memory_tokens=None,
tie_embedding=False, tie_embedding=False,
use_pos_emb=True use_pos_emb=True,
): ):
super().__init__() super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim dim = attn_layers.dim
emb_dim = default(emb_dim, dim) emb_dim = default(emb_dim, dim)
@ -571,23 +643,34 @@ class TransformerWrapper(nn.Module):
self.num_tokens = num_tokens self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim) self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( self.pos_emb = (
use_pos_emb and not attn_layers.has_pos_emb) else always(0) AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout) self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() self.project_emb = (
nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
)
self.attn_layers = attn_layers self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
self.init_() self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
# memory tokens (like [cls]) from Memory Transformers paper # memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0) num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0: if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim)
)
# let funnel encoder know number of memory tokens, if specified # let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'): if hasattr(attn_layers, 'num_memory_tokens'):
@ -605,7 +688,7 @@ class TransformerWrapper(nn.Module):
return_attn=False, return_attn=False,
mems=None, mems=None,
embedding_manager=None, embedding_manager=None,
**kwargs **kwargs,
): ):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
@ -629,7 +712,9 @@ class TransformerWrapper(nn.Module):
if exists(mask): if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True) mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
)
x = self.norm(x) x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:] mem, x = x[:, :num_mem], x[:, num_mem:]
@ -638,13 +723,30 @@ class TransformerWrapper(nn.Module):
if return_mems: if return_mems:
hiddens = intermediates.hiddens hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens new_mems = (
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) list(
map(
lambda pair: torch.cat(pair, dim=-2),
zip(mems, hiddens),
)
)
if exists(mems)
else hiddens
)
new_mems = list(
map(
lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems
)
)
return out, new_mems return out, new_mems
if return_attn: if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) attn_maps = list(
map(
lambda t: t.post_softmax_attn,
intermediates.attn_intermediates,
)
)
return out, attn_maps return out, attn_maps
return out return out

View File

@ -113,17 +113,19 @@ class T2I:
The vast majority of these arguments default to reasonable values. The vast majority of these arguments default to reasonable values.
""" """
def __init__(self,
def __init__(
self,
batch_size=1, batch_size=1,
iterations=1, iterations=1,
steps=50, steps=50,
seed=None, seed=None,
cfg_scale=7.5, cfg_scale=7.5,
weights="models/ldm/stable-diffusion-v1/model.ckpt", weights='models/ldm/stable-diffusion-v1/model.ckpt',
config = "configs/stable-diffusion/v1-inference.yaml", config='configs/stable-diffusion/v1-inference.yaml',
width=512, width=512,
height=512, height=512,
sampler_name="klms", sampler_name='klms',
latent_channels=4, latent_channels=4,
downsampling_factor=8, downsampling_factor=8,
ddim_eta=0.0, # deterministic ddim_eta=0.0, # deterministic
@ -163,13 +165,15 @@ The vast majority of these arguments default to reasonable values.
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
def prompt2png(self, prompt, outdir, **kwargs): def prompt2png(self, prompt, outdir, **kwargs):
''' """
Takes a prompt and an output directory, writes out the requested number Takes a prompt and an output directory, writes out the requested number
of PNG files, and returns an array of [[filename,seed],[filename,seed]...] of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
Optional named arguments are the same as those passed to T2I and prompt2image() Optional named arguments are the same as those passed to T2I and prompt2image()
''' """
results = self.prompt2image(prompt, **kwargs) results = self.prompt2image(prompt, **kwargs)
pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size)) pngwriter = PngWriter(
outdir, prompt, kwargs.get('batch_size', self.batch_size)
)
for r in results: for r in results:
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
pngwriter.write_image(r[0], r[1]) pngwriter.write_image(r[0], r[1])
@ -181,10 +185,13 @@ The vast majority of these arguments default to reasonable values.
def img2img(self, prompt, **kwargs): def img2img(self, prompt, **kwargs):
outdir = kwargs.get('outdir', 'outputs/img-samples') outdir = kwargs.get('outdir', 'outputs/img-samples')
assert 'init_img' in kwargs,'call to img2img() must include the init_img argument' assert (
'init_img' in kwargs
), 'call to img2img() must include the init_img argument'
return self.prompt2png(prompt, outdir, **kwargs) return self.prompt2png(prompt, outdir, **kwargs)
def prompt2image(self, def prompt2image(
self,
# these are common # these are common
prompt, prompt,
batch_size=None, batch_size=None,
@ -203,8 +210,9 @@ The vast majority of these arguments default to reasonable values.
strength=None, strength=None,
gfpgan_strength=None, gfpgan_strength=None,
variants=None, variants=None,
**args): # eat up additional cruft **args,
''' ): # eat up additional cruft
"""
ldm.prompt2image() is the common entry point for txt2img() and img2img() ldm.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments: It takes the following arguments:
prompt // prompt string (no default) prompt // prompt string (no default)
@ -232,7 +240,7 @@ The vast majority of these arguments default to reasonable values.
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
to create the requested output directory, select a unique informative name for each image, and to create the requested output directory, select a unique informative name for each image, and
write the prompt into the PNG metadata. write the prompt into the PNG metadata.
''' """
steps = steps or self.steps steps = steps or self.steps
seed = seed or self.seed seed = seed or self.seed
width = width or self.width width = width or self.width
@ -243,17 +251,23 @@ The vast majority of these arguments default to reasonable values.
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength strength = strength or self.strength
model = self.load_model() # will instantiate the model or return it from cache model = (
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0" self.load_model()
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]' ) # will instantiate the model or return it from cache
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert (
0.0 <= strength <= 1.0
), 'can only work with strength in [0.0, 1.0]'
w = int(width / 64) * 64 w = int(width / 64) * 64
h = int(height / 64) * 64 h = int(height / 64) * 64
if h != height or w != width: if h != height or w != width:
print(f'Height and width must be multiples of 64. Resizing to {h}x{w}') print(
f'Height and width must be multiples of 64. Resizing to {h}x{w}'
)
height = h height = h
width = w width = w
scope = autocast if self.precision=="autocast" else nullcontext scope = autocast if self.precision == 'autocast' else nullcontext
tic = time.time() tic = time.time()
results = list() results = list()
@ -261,30 +275,44 @@ The vast majority of these arguments default to reasonable values.
try: try:
if init_img: if init_img:
assert os.path.exists(init_img), f'{init_img}: File not found' assert os.path.exists(init_img), f'{init_img}: File not found'
images_iterator = self._img2img(prompt, images_iterator = self._img2img(
prompt,
precision_scope=scope, precision_scope=scope,
batch_size=batch_size, batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, steps=steps,
cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
init_img=init_img,strength=strength) init_img=init_img,
strength=strength,
)
else: else:
images_iterator = self._txt2img(prompt, images_iterator = self._txt2img(
prompt,
precision_scope=scope, precision_scope=scope,
batch_size=batch_size, batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, steps=steps,
cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
width=width,height=height) width=width,
height=height,
)
with scope(self.device.type), self.model.ema_scope(): with scope(self.device.type), self.model.ema_scope():
for n in trange(iterations, desc="Sampling"): for n in trange(iterations, desc='Sampling'):
seed_everything(seed) seed_everything(seed)
iter_images = next(images_iterator) iter_images = next(images_iterator)
for image in iter_images: for image in iter_images:
try: try:
if gfpgan_strength > 0: if gfpgan_strength > 0:
image = self._run_gfpgan(image, gfpgan_strength) image = self._run_gfpgan(
image, gfpgan_strength
)
except Exception as e: except Exception as e:
print(f"Error running GFPGAN - Your image was not enhanced.\n{e}") print(
f'Error running GFPGAN - Your image was not enhanced.\n{e}'
)
results.append([image, seed]) results.append([image, seed])
if image_callback is not None: if image_callback is not None:
image_callback(image, seed) image_callback(image, seed)
@ -292,58 +320,77 @@ The vast majority of these arguments default to reasonable values.
except KeyboardInterrupt: except KeyboardInterrupt:
print('*interrupted*') print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.') print(
'Partial results will be returned; if --grid was requested, nothing will be returned.'
)
except RuntimeError as e: except RuntimeError as e:
print(str(e)) print(str(e))
print('Are you sure your system has an adequate NVIDIA GPU?') print('Are you sure your system has an adequate NVIDIA GPU?')
toc = time.time() toc = time.time()
print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic)) print(f'{len(results)} images generated in', '%4.2fs' % (toc - tic))
return results return results
@torch.no_grad() @torch.no_grad()
def _txt2img(self, def _txt2img(
self,
prompt, prompt,
precision_scope, precision_scope,
batch_size, batch_size,
steps,cfg_scale,ddim_eta, steps,
cfg_scale,
ddim_eta,
skip_normalize, skip_normalize,
width,height): width,
height,
):
""" """
An infinite iterator of images from the prompt. An infinite iterator of images from the prompt.
""" """
sampler = self.sampler sampler = self.sampler
while True: while True:
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] shape = [
samples, _ = sampler.sample(S=steps, self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor,
]
samples, _ = sampler.sample(
S=steps,
conditioning=c, conditioning=c,
batch_size=batch_size, batch_size=batch_size,
shape=shape, shape=shape,
verbose=False, verbose=False,
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
eta=ddim_eta) eta=ddim_eta,
)
yield self._samples_to_images(samples) yield self._samples_to_images(samples)
@torch.no_grad() @torch.no_grad()
def _img2img(self, def _img2img(
self,
prompt, prompt,
precision_scope, precision_scope,
batch_size, batch_size,
steps,cfg_scale,ddim_eta, steps,
cfg_scale,
ddim_eta,
skip_normalize, skip_normalize,
init_img,strength): init_img,
strength,
):
""" """
An infinite iterator of images from the prompt and the initial image An infinite iterator of images from the prompt and the initial image
""" """
# PLMS sampler not supported yet, so ignore previous sampler # PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name != 'ddim': if self.sampler_name != 'ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler") print(
f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
)
sampler = DDIMSampler(self.model, device=self.device) sampler = DDIMSampler(self.model, device=self.device)
else: else:
sampler = self.sampler sampler = self.sampler
@ -351,9 +398,13 @@ The vast majority of these arguments default to reasonable values.
init_image = self._load_img(init_img).to(self.device) init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
with precision_scope(self.device.type): with precision_scope(self.device.type):
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
t_enc = int(strength * steps) t_enc = int(strength * steps)
# print(f"target t_enc is {t_enc} steps") # print(f"target t_enc is {t_enc} steps")
@ -362,16 +413,23 @@ The vast majority of these arguments default to reasonable values.
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
# encode (scaled latent) # encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc] * batch_size).to(self.device)
)
# decode it # decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, samples = sampler.decode(
unconditional_conditioning=uc,) z_enc,
c,
t_enc,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
)
yield self._samples_to_images(samples) yield self._samples_to_images(samples)
# TODO: does this actually need to run every loop? does anything in it vary by random seed? # TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, batch_size, skip_normalize): def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
uc = self.model.get_learned_conditioning(batch_size * [""]) uc = self.model.get_learned_conditioning(batch_size * [''])
# weighted sub-prompts # weighted sub-prompts
subprompts, weights = T2I._split_weighted_subprompts(prompt) subprompts, weights = T2I._split_weighted_subprompts(prompt)
@ -385,7 +443,13 @@ The vast majority of these arguments default to reasonable values.
weight = weights[i] weight = weights[i]
if not skip_normalize: if not skip_normalize:
weight = weight / totalWeight weight = weight / totalWeight
c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight) c = torch.add(
c,
self.model.get_learned_conditioning(
batch_size * [subprompts[i]]
),
alpha=weight,
)
else: # just standard 1 prompt else: # just standard 1 prompt
c = self.model.get_learned_conditioning(batch_size * [prompt]) c = self.model.get_learned_conditioning(batch_size * [prompt])
return (uc, c) return (uc, c)
@ -395,7 +459,9 @@ The vast majority of these arguments default to reasonable values.
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
images = list() images = list()
for x_sample in x_samples: for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255.0 * rearrange(
x_sample.cpu().numpy(), 'c h w -> h w c'
)
image = Image.fromarray(x_sample.astype(np.uint8)) image = Image.fromarray(x_sample.astype(np.uint8))
images.append(image) images.append(image)
return images return images
@ -410,7 +476,11 @@ The vast majority of these arguments default to reasonable values.
seed_everything(self.seed) seed_everything(self.seed)
try: try:
config = OmegaConf.load(self.config) config = OmegaConf.load(self.config)
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu") self.device = (
torch.device(self.device)
if torch.cuda.is_available()
else torch.device('cpu')
)
model = self._load_model_from_config(config, self.weights) model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None: if self.embedding_path is not None:
model.embedding_manager.load(self.embedding_path) model.embedding_manager.load(self.embedding_path)
@ -426,13 +496,21 @@ The vast majority of these arguments default to reasonable values.
elif self.sampler_name == 'ddim': elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device) self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a': elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device) self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
elif self.sampler_name == 'k_dpm_2': elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device) self.sampler = KSampler(
self.model, 'dpm_2', device=self.device
)
elif self.sampler_name == 'k_euler_a': elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device) self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
elif self.sampler_name == 'k_euler': elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device) self.sampler = KSampler(
self.model, 'euler', device=self.device
)
elif self.sampler_name == 'k_heun': elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device) self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms': elif self.sampler_name == 'k_lms':
@ -446,32 +524,38 @@ The vast majority of these arguments default to reasonable values.
return self.model return self.model
def _load_model_from_config(self, config, ckpt): def _load_model_from_config(self, config, ckpt):
print(f"Loading model from {ckpt}") print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location='cpu')
# if "global_step" in pl_sd: # if "global_step" in pl_sd:
# print(f"Global Step: {pl_sd['global_step']}") # print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"] sd = pl_sd['state_dict']
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
model.to(self.device) model.to(self.device)
model.eval() model.eval()
if self.full_precision: if self.full_precision:
print('Using slower but more accurate full-precision math (--full_precision)') print(
'Using slower but more accurate full-precision math (--full_precision)'
)
else: else:
print('Using half precision math. Call with --full_precision to use slower but more accurate full precision.') print(
'Using half precision math. Call with --full_precision to use slower but more accurate full precision.'
)
model.half() model.half()
return model return model
def _load_img(self, path): def _load_img(self, path):
image = Image.open(path).convert("RGB") image = Image.open(path).convert('RGB')
w, h = image.size w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}") print(f'loaded input image of size ({w}, {h}) from {path}')
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = map(
lambda x: x - x % 32, (w, h)
) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS) image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)
return 2.*image - 1. return 2.0 * image - 1.0
def _split_weighted_subprompts(text): def _split_weighted_subprompts(text):
""" """
@ -484,23 +568,25 @@ The vast majority of these arguments default to reasonable values.
prompts = [] prompts = []
weights = [] weights = []
while remaining > 0: while remaining > 0:
if ":" in text: if ':' in text:
idx = text.index(":") # first occurrence from start idx = text.index(':') # first occurrence from start
# grab up to index as sub-prompt # grab up to index as sub-prompt
prompt = text[:idx] prompt = text[:idx]
remaining -= idx remaining -= idx
# remove from main text # remove from main text
text = text[idx + 1 :] text = text[idx + 1 :]
# find value for weight # find value for weight
if " " in text: if ' ' in text:
idx = text.index(" ") # first occurence idx = text.index(' ') # first occurence
else: # no space, read to end else: # no space, read to end
idx = len(text) idx = len(text)
if idx != 0: if idx != 0:
try: try:
weight = float(text[:idx]) weight = float(text[:idx])
except: # couldn't treat as float except: # couldn't treat as float
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") print(
f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
)
weight = 1.0 weight = 1.0
else: # no value found else: # no value found
weight = 1.0 weight = 1.0
@ -519,13 +605,20 @@ The vast majority of these arguments default to reasonable values.
return prompts, weights return prompts, weights
def _run_gfpgan(self, image, strength): def _run_gfpgan(self, image, strength):
if (self.gfpgan is None): if self.gfpgan is None:
print(f"GFPGAN not initialized, it must be loaded via the --gfpgan argument") print(
f'GFPGAN not initialized, it must be loaded via the --gfpgan argument'
)
return image return image
image = image.convert("RGB") image = image.convert('RGB')
cropped_faces, restored_faces, restored_img = self.gfpgan.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, restored_img = self.gfpgan.enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
if strength < 1.0: if strength < 1.0:

View File

@ -13,22 +13,25 @@ from queue import Queue
from inspect import isfunction from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10): def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height) # wh a tuple of (width, height)
# xc a list of captions to plot # xc a list of captions to plot
b = len(xc) b = len(xc)
txts = list() txts = list()
for bi in range(b): for bi in range(b):
txt = Image.new("RGB", wh, color="white") txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt) draw = ImageDraw.Draw(txt)
font = ImageFont.load_default() font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256)) nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) lines = '\n'.join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try: try:
draw.text((0, 0), lines, fill="black", font=font) draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError: except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.") print('Cant encode string for logging. Skipping.')
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt) txts.append(txt)
@ -70,22 +73,26 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
    """Return the summed ``numel()`` of everything in ``model.parameters()``.

    When ``verbose`` is true, the total is also printed in millions
    together with the model's class name.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    if not verbose:
        return total_params
    print(
        f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
    )
    return total_params
def instantiate_from_config(config, **kwargs):
    """Instantiate the object named by ``config['target']``.

    The target string is resolved with ``get_obj_from_str`` and called
    with ``config['params']`` (if present) merged with ``kwargs``.

    Args:
        config: mapping with a ``'target'`` key and an optional
            ``'params'`` mapping, or one of the sentinel strings
            ``'__is_first_stage__'`` / ``'__is_unconditional__'``.
        **kwargs: extra keyword arguments forwarded to the target.

    Returns:
        The constructed object, or ``None`` for the two sentinel strings.

    Raises:
        KeyError: if ``config`` has no ``'target'`` and is not a sentinel.
    """
    if 'target' not in config:
        # These two sentinels yield None instead of an error.
        if config == '__is_first_stage__':
            return None
        elif config == '__is_unconditional__':
            return None
        raise KeyError('Expected key `target` to instantiate.')
    return get_obj_from_str(config['target'])(
        **config.get('params', {}), **kwargs
    )
def get_obj_from_str(string, reload=False): def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1) module, cls = string.rsplit('.', 1)
if reload: if reload:
module_imp = importlib.import_module(module) module_imp = importlib.import_module(module)
importlib.reload(module_imp) importlib.reload(module_imp)
@ -101,31 +108,36 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else: else:
res = func(data) res = func(data)
Q.put([idx, res]) Q.put([idx, res])
Q.put("Done") Q.put('Done')
def parallel_data_prefetch( def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False func: callable,
data,
n_proc,
target_data_type='ndarray',
cpu_intensive=True,
use_worker_id=False,
): ):
# if target_data_type not in ["ndarray", "list"]: # if target_data_type not in ["ndarray", "list"]:
# raise ValueError( # raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# ) # )
if isinstance(data, np.ndarray) and target_data_type == "list": if isinstance(data, np.ndarray) and target_data_type == 'list':
raise ValueError("list expected but function got ndarray.") raise ValueError('list expected but function got ndarray.')
elif isinstance(data, abc.Iterable): elif isinstance(data, abc.Iterable):
if isinstance(data, dict): if isinstance(data, dict):
print( print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
) )
data = list(data.values()) data = list(data.values())
if target_data_type == "ndarray": if target_data_type == 'ndarray':
data = np.asarray(data) data = np.asarray(data)
else: else:
data = list(data) data = list(data)
else: else:
raise TypeError( raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
) )
if cpu_intensive: if cpu_intensive:
@ -135,7 +147,7 @@ def parallel_data_prefetch(
Q = Queue(1000) Q = Queue(1000)
proc = Thread proc = Thread
# spawn processes # spawn processes
if target_data_type == "ndarray": if target_data_type == 'ndarray':
arguments = [ arguments = [
[func, Q, part, i, use_worker_id] [func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc)) for i, part in enumerate(np.array_split(data, n_proc))
@ -158,7 +170,7 @@ def parallel_data_prefetch(
processes += [p] processes += [p]
# start processes # start processes
print(f"Start prefetching...") print(f'Start prefetching...')
import time import time
start = time.time() start = time.time()
@ -171,13 +183,13 @@ def parallel_data_prefetch(
while k < n_proc: while k < n_proc:
# get result # get result
res = Q.get() res = Q.get()
if res == "Done": if res == 'Done':
k += 1 k += 1
else: else:
gather_res[res[0]] = res[1] gather_res[res[0]] = res[1]
except Exception as e: except Exception as e:
print("Exception: ", e) print('Exception: ', e)
for p in processes: for p in processes:
p.terminate() p.terminate()
@ -185,7 +197,7 @@ def parallel_data_prefetch(
finally: finally:
for p in processes: for p in processes:
p.join() p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]") print(f'Prefetching complete. [{time.time() - start} sec.]')
if target_data_type == 'ndarray': if target_data_type == 'ndarray':
if not isinstance(gather_res[0], np.ndarray): if not isinstance(gather_res[0], np.ndarray):

659
main.py

File diff suppressed because it is too large Load Diff

View File

@ -12,37 +12,41 @@ from ldm.dream.pngwriter import PngWriter,PromptFormatter
debugging = False debugging = False
def main(): def main():
''' Initialize command-line parsers and the diffusion model ''' """Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser() arg_parser = create_argv_parser()
opt = arg_parser.parse_args() opt = arg_parser.parse_args()
if opt.laion400m: if opt.laion400m:
# defaults suitable to the older latent diffusion weights # defaults suitable to the older latent diffusion weights
width = 256 width = 256
height = 256 height = 256
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml'
weights = "models/ldm/text2img-large/model.ckpt" weights = 'models/ldm/text2img-large/model.ckpt'
else: else:
# some defaults suitable for stable diffusion weights # some defaults suitable for stable diffusion weights
width = 512 width = 512
height = 512 height = 512
config = "configs/stable-diffusion/v1-inference.yaml" config = 'configs/stable-diffusion/v1-inference.yaml'
weights = "models/ldm/stable-diffusion-v1/model.ckpt" weights = 'models/ldm/stable-diffusion-v1/model.ckpt'
print("* Initializing, be patient...\n") print('* Initializing, be patient...\n')
sys.path.append('.') sys.path.append('.')
from pytorch_lightning import logging from pytorch_lightning import logging
from ldm.simplet2i import T2I from ldm.simplet2i import T2I
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported # when the frozen CLIP tokenizer is imported
import transformers import transformers
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# creating a simple text2image object with a handful of # creating a simple text2image object with a handful of
# defaults passed on the command line. # defaults passed on the command line.
# additional parameters will be added (or overridden) during
# the user input loop # the user input loop
t2i = T2I(width=width, t2i = T2I(
width=width,
height=height, height=height,
sampler_name=opt.sampler_name, sampler_name=opt.sampler_name,
weights=weights, weights=weights,
@ -50,7 +54,7 @@ def main():
config=config, config=config,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
embedding_path=opt.embedding_path, embedding_path=opt.embedding_path,
device=opt.device device=opt.device,
) )
# make sure the output directory exists # make sure the output directory exists
@ -58,7 +62,7 @@ def main():
os.makedirs(opt.outdir) os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed # gets rid of annoying messages about random seed
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
infile = None infile = None
try: try:
@ -73,27 +77,42 @@ def main():
# load GFPGAN if requested # load GFPGAN if requested
if opt.use_gfpgan: if opt.use_gfpgan:
print("\n* --gfpgan was specified, loading gfpgan...") print('\n* --gfpgan was specified, loading gfpgan...')
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
try: try:
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path) model_path = os.path.join(
opt.gfpgan_dir, opt.gfpgan_model_path
)
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise Exception("GFPGAN model not found at path "+model_path) raise Exception(
'GFPGAN model not found at path ' + model_path
)
sys.path.append(os.path.abspath(opt.gfpgan_dir)) sys.path.append(os.path.abspath(opt.gfpgan_dir))
from gfpgan import GFPGANer from gfpgan import GFPGANer
bg_upsampler = load_gfpgan_bg_upsampler(opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile) bg_upsampler = load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile
)
t2i.gfpgan = GFPGANer(model_path=model_path, upscale=opt.gfpgan_upscale, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler) t2i.gfpgan = GFPGANer(
model_path=model_path,
upscale=opt.gfpgan_upscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler,
)
except Exception: except Exception:
import traceback import traceback
print("Error loading GFPGAN:", file=sys.stderr)
print('Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print("\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...") print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)..."
)
log_path = os.path.join(opt.outdir, 'dream_log.txt') log_path = os.path.join(opt.outdir, 'dream_log.txt')
with open(log_path, 'a') as log: with open(log_path, 'a') as log:
@ -105,13 +124,13 @@ def main():
def main_loop(t2i, outdir, parser, log, infile): def main_loop(t2i, outdir, parser, log, infile):
''' prompt/read/execute loop ''' """prompt/read/execute loop"""
done = False done = False
last_seeds = [] last_seeds = []
while not done: while not done:
try: try:
command = infile.readline() if infile else input("dream> ") command = infile.readline() if infile else input('dream> ')
except EOFError: except EOFError:
done = True done = True
break break
@ -142,17 +161,19 @@ def main_loop(t2i,outdir,parser,log,infile):
if elements[0] == 'cd' and len(elements) > 1: if elements[0] == 'cd' and len(elements) > 1:
if os.path.exists(elements[1]): if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}") print(f'setting image output directory to {elements[1]}')
outdir = elements[1] outdir = elements[1]
else: else:
print(f"directory {elements[1]} does not exist") print(f'directory {elements[1]} does not exist')
continue continue
if elements[0] == 'pwd': if elements[0] == 'pwd':
print(f"current output directory is {outdir}") print(f'current output directory is {outdir}')
continue continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command if elements[0].startswith(
'!dream'
): # in case a stored prompt still contains the !dream command
elements.pop(0) elements.pop(0)
# rearrange the arguments to mimic how it works in the Dream bot. # rearrange the arguments to mimic how it works in the Dream bot.
@ -175,14 +196,14 @@ def main_loop(t2i,outdir,parser,log,infile):
parser.print_help() parser.print_help()
continue continue
if len(opt.prompt) == 0: if len(opt.prompt) == 0:
print("Try again with a prompt!") print('Try again with a prompt!')
continue continue
if opt.seed is not None and opt.seed < 0: # retrieve previous value! if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try: try:
opt.seed = last_seeds[opt.seed] opt.seed = last_seeds[opt.seed]
print(f"reusing previous seed {opt.seed}") print(f'reusing previous seed {opt.seed}')
except IndexError: except IndexError:
print(f"No previous seed at position {opt.seed} found") print(f'No previous seed at position {opt.seed} found')
opt.seed = None opt.seed = None
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
@ -193,7 +214,9 @@ def main_loop(t2i,outdir,parser,log,infile):
callback = file_writer.write_image if individual_images else None callback = file_writer.write_image if individual_images else None
image_list = t2i.prompt2image(image_callback=callback, **vars(opt)) image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
results = file_writer.files_written if individual_images else image_list results = (
file_writer.files_written if individual_images else image_list
)
if opt.grid and len(results) > 0: if opt.grid and len(results) > 0:
grid_img = file_writer.make_grid([r[0] for r in results]) grid_img = file_writer.make_grid([r[0] for r in results])
@ -201,7 +224,9 @@ def main_loop(t2i,outdir,parser,log,infile):
seeds = [a[1] for a in results] seeds = [a[1] for a in results]
results = [[filename, seeds]] results = [[filename, seeds]]
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}' metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
file_writer.save_image_and_prompt_to_png(grid_img,metadata_prompt,filename) file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
)
last_seeds = [r[1] for r in results] last_seeds = [r[1] for r in results]
@ -213,10 +238,11 @@ def main_loop(t2i,outdir,parser,log,infile):
print(e) print(e)
continue continue
print("Outputs:") print('Outputs:')
write_log_message(t2i, normalized_prompt, results, log) write_log_message(t2i, normalized_prompt, results, log)
print("goodbye!") print('goodbye!')
def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400): def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
import torch import torch
@ -224,13 +250,24 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
if bg_upsampler == 'realesrgan': if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU if not torch.cuda.is_available(): # CPU
import warnings import warnings
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.') warnings.warn(
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.'
)
bg_upsampler = None bg_upsampler = None
else: else:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer( bg_upsampler = RealESRGANer(
scale=2, scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
@ -238,12 +275,14 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
tile=bg_tile, tile=bg_tile,
tile_pad=10, tile_pad=10,
pre_pad=0, pre_pad=0,
half=True) # need to set False in CPU mode half=True,
) # need to set False in CPU mode
else: else:
bg_upsampler = None bg_upsampler = None
return bg_upsampler return bg_upsampler
# variant generation is going to be superseded by a generalized # variant generation is going to be superseded by a generalized
# "prompt-morph" functionality # "prompt-morph" functionality
# def generate_variants(t2i,outdir,opt,previous_gens): # def generate_variants(t2i,outdir,opt,previous_gens):
@ -271,107 +310,206 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
def write_log_message(t2i, prompt, results, logfile):
    """Print one log line per result and append it to ``logfile``.

    Each line has the form ``<r[0]>: <prompt> -S<r[1]>``.  The file is
    flushed after every line.  This function only writes to the terminal
    and to ``logfile``; it does not touch PNG metadata.

    Args:
        t2i: unused; kept so existing callers keep working.
        prompt: prompt text placed in every log line.
        results: iterable of indexable entries whose element 0 is the
            output name and element 1 the seed value to report.
        logfile: writable text file object.
    """
    for result in results:
        log_message = f'{result[0]}: {prompt} -S{result[1]}'
        print(log_message)
        logfile.write(log_message + '\n')
        logfile.flush()
def create_argv_parser():
    """Build the parser for the script's own command-line arguments.

    Covers model selection, precision, sampler, output directory,
    embedding path, device, and the GFPGAN options.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Parse script's command line args"
    )
    parser.add_argument(
        '--laion400m',
        '--latent_diffusion',
        '-l',
        dest='laion400m',
        action='store_true',
        help='fallback to the latent diffusion (laion400m) weights and config',
    )
    parser.add_argument(
        '--from_file',
        dest='infile',
        type=str,
        help='if specified, load prompts from this file',
    )
    parser.add_argument(
        '-n',
        '--iterations',
        type=int,
        default=1,
        help='number of images to generate',
    )
    parser.add_argument(
        '-F',
        '--full_precision',
        dest='full_precision',
        action='store_true',
        help='use slower full precision math for calculations',
    )
    parser.add_argument(
        '--sampler',
        '-m',
        dest='sampler_name',
        choices=[
            'ddim',
            'k_dpm_2_a',
            'k_dpm_2',
            'k_euler_a',
            'k_euler',
            'k_heun',
            'k_lms',
            'plms',
        ],
        default='k_lms',
        help='which sampler to use (k_lms) - can only be set on command line',
    )
    parser.add_argument(
        '--outdir',
        '-o',
        type=str,
        default='outputs/img-samples',
        # fixed: the closing parenthesis was missing
        help='directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)',
    )
    parser.add_argument(
        '--embedding_path',
        type=str,
        help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
    )
    parser.add_argument(
        '--device',
        '-d',
        type=str,
        default='cuda',
        # fixed: "avalible" typo
        help='device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available',
    )
    # GFPGAN related args
    parser.add_argument(
        '--gfpgan',
        dest='use_gfpgan',
        action='store_true',
        help='load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory',
    )
    parser.add_argument(
        '--gfpgan_upscale',
        type=int,
        default=2,
        help='The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified',
    )
    parser.add_argument(
        '--gfpgan_bg_upsampler',
        type=str,
        default='realesrgan',
        # fixed: the help said "Default: None" although the default is
        # 'realesrgan'
        help='Background upsampler. Default: realesrgan. Options: realesrgan, none. Only used if --gfpgan is specified',
    )
    parser.add_argument(
        '--gfpgan_bg_tile',
        type=int,
        default=400,
        help='Tile size for background sampler, 0 for no tile during testing. Default: 400. Only used if --gfpgan is specified',
    )
    parser.add_argument(
        '--gfpgan_model_path',
        type=str,
        default='experiments/pretrained_models/GFPGANv1.3.pth',
        help='indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified',
    )
    parser.add_argument(
        '--gfpgan_dir',
        type=str,
        default='../GFPGAN',
        help='indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified',
    )
    return parser
def create_cmd_parser():
    """Build the parser applied to each line typed at the ``dream>`` prompt.

    The parser takes one positional ``prompt`` plus the per-image
    switches listed in the table below, registered in table order.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cmd_parser = argparse.ArgumentParser(
        description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
    )
    cmd_parser.add_argument('prompt')

    # (flags, keyword arguments) for every switch, in registration order
    switches = [
        (('-s', '--steps'), dict(type=int, help='number of steps')),
        (
            ('-S', '--seed'),
            dict(
                type=int,
                help='image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
            ),
        ),
        (
            ('-n', '--iterations'),
            dict(
                type=int,
                default=1,
                help='number of samplings to perform (slower, but will provide seeds for individual images)',
            ),
        ),
        (
            ('-b', '--batch_size'),
            dict(
                type=int,
                default=1,
                help='number of images to produce per sampling (will not provide seeds for individual images!)',
            ),
        ),
        (
            ('-W', '--width'),
            dict(type=int, help='image width, multiple of 64'),
        ),
        (
            ('-H', '--height'),
            dict(type=int, help='image height, multiple of 64'),
        ),
        (
            ('-C', '--cfg_scale'),
            dict(
                default=7.5,
                type=float,
                help='prompt configuration scale',
            ),
        ),
        (
            ('-g', '--grid'),
            dict(action='store_true', help='generate a grid'),
        ),
        (
            ('-i', '--individual'),
            dict(
                action='store_true',
                help='generate individual files (default)',
            ),
        ),
        (
            ('-I', '--init_img'),
            dict(
                type=str,
                help='path to input image for img2img mode (supersedes width and height)',
            ),
        ),
        (
            ('-f', '--strength'),
            dict(
                default=0.75,
                type=float,
                help='strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
            ),
        ),
        (
            ('-G', '--gfpgan_strength'),
            dict(
                default=0.5,
                type=float,
                help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
            ),
        ),
        # variants is going to be superseded by a generalized "prompt-morph" function
        # (('-v', '--variants'), dict(type=int, help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")),
        (
            ('-x', '--skip_normalize'),
            dict(
                action='store_true',
                help='skip subprompt weight normalization',
            ),
        ),
    ]
    for flags, options in switches:
        cmd_parser.add_argument(*flags, **options)
    return cmd_parser
# Call main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

View File

@ -11,17 +11,18 @@ import warnings
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# this will preload the Bert tokenizer fles # this will preload the Bert tokenizer fles
print("preloading bert tokenizer...") print('preloading bert tokenizer...')
from transformers import BertTokenizerFast from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
print("...success") tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
# this will download requirements for Kornia # this will download requirements for Kornia
print("preloading Kornia requirements (ignore the deprecation warnings)...") print('preloading Kornia requirements (ignore the deprecation warnings)...')
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia import kornia
print("...success") print('...success')
version = 'openai/clip-vit-large-patch14' version = 'openai/clip-vit-large-patch14'
@ -29,6 +30,7 @@ print('preloading CLIP model (Ignore the deprecation warnings)...')
sys.stdout.flush() sys.stdout.flush()
import clip import clip
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
tokenizer = CLIPTokenizer.from_pretrained(version) tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version) transformer = CLIPTextModel.from_pretrained(version)
print('\n\n...success') print('\n\n...success')
@ -38,23 +40,33 @@ print('\n\n...success')
gfpgan = False gfpgan = False
try: try:
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
gfpgan = True gfpgan = True
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
if gfpgan: if gfpgan:
print("Loading models from RealESRGAN and facexlib") print('Loading models from RealESRGAN and facexlib')
try: try:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(scale=2,
RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)) model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
),
)
FaceRestoreHelper(1, det_model='retinaface_resnet50') FaceRestoreHelper(1, det_model='retinaface_resnet50')
print("...success") print('...success')
except Exception: except Exception:
import traceback import traceback
print("Error loading GFPGAN:")
print('Error loading GFPGAN:')
print(traceback.format_exc()) print(traceback.format_exc())