prettified all the code using "blue" at the urging of @tildebyte

Lincoln Stein 2022-08-26 03:15:42 -04:00
parent dd670200bb
commit 4f02b72c9c
35 changed files with 6252 additions and 3119 deletions
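blue is a fork of the black code formatter whose main visible difference is a preference for single-quoted strings; the hunks below are the mechanical result of running it over the repository, with no behavioral changes. As an illustrative sketch only (this snippet is not taken from the commit, and assumes nothing beyond blue's black-style wrapping plus single quotes), a mapping like the interpolation tables in the datasets below goes from:

    import PIL.Image

    # Before formatting: double quotes, one long line.
    interpolation = {"nearest": PIL.Image.NEAREST, "bilinear": PIL.Image.BILINEAR, "bicubic": PIL.Image.BICUBIC, "lanczos": PIL.Image.LANCZOS}

    # After running blue: single quotes, one entry per line.
    interpolation = {
        'nearest': PIL.Image.NEAREST,
        'bilinear': PIL.Image.BILINEAR,
        'bicubic': PIL.Image.BICUBIC,
        'lanczos': PIL.Image.LANCZOS,
    }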

View File

@@ -1,11 +1,17 @@
from abc import abstractmethod
from torch.utils.data import (
    Dataset,
    ConcatDataset,
    ChainDataset,
    IterableDataset,
)

class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records

@@ -13,11 +19,13 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
        self.sample_ids = valid_ids
        self.size = size

        print(
            f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
        )

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass

View File

@@ -11,24 +11,34 @@ from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import (
    str_to_indices,
    give_synsets_from_indices,
    download,
    retrieve,
)
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import (
    degradation_fn_bsr,
    degradation_fn_bsr_light,
)

def synset2idx(path_to_yaml='data/index_synset.yaml'):
    with open(path_to_yaml) as f:
        di2s = yaml.load(f)
    return dict((v, k) for k, v in di2s.items())

class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config) == dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get(
            'keep_orig_class_label', False
        )
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()

@@ -46,17 +56,23 @@ class ImageNetBase(Dataset):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set(
            [
                'n06596364_9591.JPEG',
            ]
        )
        relpaths = [
            rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore
        ]
        if 'sub_indices' in self.config:
            indices = str_to_indices(self.config['sub_indices'])
            synsets = give_synsets_from_indices(
                indices, path_to_yaml=self.idx2syn
            )  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split('/')[0]
                if syn in synsets:
                    files.append(rpath)
            return files

@@ -65,78 +81,89 @@ class ImageNetBase(Dataset):
    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
        self.human_dict = os.path.join(self.root, 'synset_human.txt')
        if (
            not os.path.exists(self.human_dict)
            or not os.path.getsize(self.human_dict) == SIZE
        ):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
        self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
        self.human2integer = os.path.join(
            self.root, 'imagenet1000_clsidx_to_labels.txt'
        )
        if not os.path.exists(self.human2integer):
            download(URL, self.human2integer)
        with open(self.human2integer, 'r') as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(':')
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, 'r') as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print(
                'Removed {} files from filelist during filtering.'.format(
                    l1 - len(self.relpaths)
                )
            )

        self.synsets = [p.split('/')[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict(
            (synset, i) for i, synset in enumerate(unique_synsets)
        )
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, 'r') as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            'relpath': np.array(self.relpaths),
            'synsets': np.array(self.synsets),
            'class_label': np.array(self.class_labels),
            'human_label': np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, 'size', default=256)
            self.data = ImagePaths(
                self.abspaths,
                labels=labels,
                size=self.size,
                random_crop=self.random_crop,
            )
        else:
            self.data = self.abspaths

class ImageNetTrain(ImageNetBase):
    NAME = 'ILSVRC2012_train'
    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
    AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
    FILES = [
        'ILSVRC2012_img_train.tar',
    ]
    SIZES = [
        147897477120,

@@ -151,57 +178,64 @@ class ImageNetTrain(ImageNetBase):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get(
                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
            )
            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)

        self.datadir = os.path.join(self.root, 'data')
        self.txt_filelist = os.path.join(self.root, 'filelist.txt')
        self.expected_length = 1281167
        self.random_crop = retrieve(
            self.config, 'ImageNetTrain/random_crop', default=True
        )
        if not tdu.is_prepared(self.root):
            # prep
            print('Preparing dataset {} in {}'.format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if (
                    not os.path.exists(path)
                    or not os.path.getsize(path) == self.SIZES[0]
                ):
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print('Extracting {} to {}'.format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, 'r:') as tar:
                    tar.extractall(path=datadir)

                print('Extracting sub-tars.')
                subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
                for subpath in tqdm(subpaths):
                    subdir = subpath[: -len('.tar')]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, 'r:') as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = '\n'.join(filelist) + '\n'
            with open(self.txt_filelist, 'w') as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)

class ImageNetValidation(ImageNetBase):
    NAME = 'ILSVRC2012_validation'
    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
    AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
    VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
    FILES = [
        'ILSVRC2012_img_val.tar',
        'validation_synset.txt',
    ]
    SIZES = [
        6744924160,

@@ -217,39 +251,49 @@ class ImageNetValidation(ImageNetBase):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get(
                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
            )
            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
        self.datadir = os.path.join(self.root, 'data')
        self.txt_filelist = os.path.join(self.root, 'filelist.txt')
        self.expected_length = 50000
        self.random_crop = retrieve(
            self.config, 'ImageNetValidation/random_crop', default=False
        )
        if not tdu.is_prepared(self.root):
            # prep
            print('Preparing dataset {} in {}'.format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if (
                    not os.path.exists(path)
                    or not os.path.getsize(path) == self.SIZES[0]
                ):
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print('Extracting {} to {}'.format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, 'r:') as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if (
                    not os.path.exists(vspath)
                    or not os.path.getsize(vspath) == self.SIZES[1]
                ):
                    download(self.VS_URL, vspath)

                with open(vspath, 'r') as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print('Reorganizing into synset folders')
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)

@@ -258,21 +302,26 @@ class ImageNetValidation(ImageNetBase):
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = '\n'.join(filelist) + '\n'
            with open(self.txt_filelist, 'w') as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)

class ImageNetSR(Dataset):
    def __init__(
        self,
        size=None,
        degradation=None,
        downscale_f=4,
        min_crop_f=0.5,
        max_crop_f=1.0,
        random_crop=True,
    ):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:

@@ -296,67 +345,86 @@ class ImageNetSR(Dataset):
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert max_crop_f <= 1.0
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(
            max_size=size, interpolation=cv2.INTER_AREA
        )

        self.pil_interpolation = (
            False  # gets reset later if incase interp_op is from pillow
        )

        if degradation == 'bsrgan':
            self.degradation_process = partial(
                degradation_fn_bsr, sf=downscale_f
            )

        elif degradation == 'bsrgan_light':
            self.degradation_process = partial(
                degradation_fn_bsr_light, sf=downscale_f
            )

        else:
            interpolation_fn = {
                'cv_nearest': cv2.INTER_NEAREST,
                'cv_bilinear': cv2.INTER_LINEAR,
                'cv_bicubic': cv2.INTER_CUBIC,
                'cv_area': cv2.INTER_AREA,
                'cv_lanczos': cv2.INTER_LANCZOS4,
                'pil_nearest': PIL.Image.NEAREST,
                'pil_bilinear': PIL.Image.BILINEAR,
                'pil_bicubic': PIL.Image.BICUBIC,
                'pil_box': PIL.Image.BOX,
                'pil_hamming': PIL.Image.HAMMING,
                'pil_lanczos': PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith('pil_')

            if self.pil_interpolation:
                self.degradation_process = partial(
                    TF.resize,
                    size=self.LR_size,
                    interpolation=interpolation_fn,
                )

            else:
                self.degradation_process = albumentations.SmallestMaxSize(
                    max_size=self.LR_size, interpolation=interpolation_fn
                )

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example['file_path_'])

        if not image.mode == 'RGB':
            image = image.convert('RGB')

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(
            self.min_crop_f, self.max_crop_f, size=None
        )
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(
                height=crop_side_len, width=crop_side_len
            )

        else:
            self.cropper = albumentations.RandomCrop(
                height=crop_side_len, width=crop_side_len
            )

        image = self.cropper(image=image)['image']
        image = self.image_rescaler(image=image)['image']

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)

@@ -364,10 +432,10 @@ class ImageNetSR(Dataset):
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)['image']

        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
        example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)

        return example

@@ -377,9 +445,11 @@ class ImageNetSRTrain(ImageNetSR):
        super().__init__(**kwargs)

    def get_base(self):
        with open('data/imagenet_train_hr_indices.p', 'rb') as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(
            process_images=False,
        )
        return Subset(dset, indices)

@@ -388,7 +458,9 @@ class ImageNetSRValidation(ImageNetSR):
        super().__init__(**kwargs)

    def get_base(self):
        with open('data/imagenet_val_hr_indices.p', 'rb') as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(
            process_images=False,
        )
        return Subset(dset, indices)

View File

@@ -7,30 +7,33 @@ from torchvision import transforms
class LSUNBase(Dataset):
    def __init__(
        self,
        txt_file,
        data_root,
        size=None,
        interpolation='bicubic',
        flip_p=0.5,
    ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, 'r') as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            'relative_file_path_': [l for l in self.image_paths],
            'file_path_': [
                os.path.join(self.data_root, l) for l in self.image_paths
            ],
        }

        self.size = size
        self.interpolation = {
            'linear': PIL.Image.LINEAR,
            'bilinear': PIL.Image.BILINEAR,
            'bicubic': PIL.Image.BICUBIC,
            'lanczos': PIL.Image.LANCZOS,
        }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):

@@ -38,55 +41,86 @@ class LSUNBase(Dataset):
    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example['file_path_'])
        if not image.mode == 'RGB':
            image = image.convert('RGB')

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w, = (
            img.shape[0],
            img.shape[1],
        )
        img = img[
            (h - crop) // 2 : (h + crop) // 2,
            (w - crop) // 2 : (w + crop) // 2,
        ]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize(
                (self.size, self.size), resample=self.interpolation
            )

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
        return example

class LSUNChurchesTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(
            txt_file='data/lsun/church_outdoor_train.txt',
            data_root='data/lsun/churches',
            **kwargs
        )

class LSUNChurchesValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(
            txt_file='data/lsun/church_outdoor_val.txt',
            data_root='data/lsun/churches',
            flip_p=flip_p,
            **kwargs
        )

class LSUNBedroomsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(
            txt_file='data/lsun/bedrooms_train.txt',
            data_root='data/lsun/bedrooms',
            **kwargs
        )

class LSUNBedroomsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(
            txt_file='data/lsun/bedrooms_val.txt',
            data_root='data/lsun/bedrooms',
            flip_p=flip_p,
            **kwargs
        )

class LSUNCatsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(
            txt_file='data/lsun/cat_train.txt',
            data_root='data/lsun/cats',
            **kwargs
        )

class LSUNCatsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(
            txt_file='data/lsun/cat_val.txt',
            data_root='data/lsun/cats',
            flip_p=flip_p,
            **kwargs
        )

View File

@@ -72,31 +72,57 @@ imagenet_dual_templates_small = [
]

per_img_token_list = [
    'א',
    'ב',
    'ג',
    'ד',
    'ה',
    'ו',
    'ז',
    'ח',
    'ט',
    'י',
    'כ',
    'ל',
    'מ',
    'נ',
    'ס',
    'ע',
    'פ',
    'צ',
    'ק',
    'ר',
    'ש',
    'ת',
]

class PersonalizedBase(Dataset):
    def __init__(
        self,
        data_root,
        size=None,
        repeats=100,
        interpolation='bicubic',
        flip_p=0.5,
        set='train',
        placeholder_token='*',
        per_image_tokens=False,
        center_crop=False,
        mixing_prob=0.25,
        coarse_class_text=None,
    ):

        self.data_root = data_root

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
        ]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token

@@ -107,17 +133,20 @@ class PersonalizedBase(Dataset):
        self.coarse_class_text = coarse_class_text

        if per_image_tokens:
            assert self.num_images < len(
                per_img_token_list
            ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == 'train':
            self._length = self.num_images * repeats

        self.size = size
        self.interpolation = {
            'linear': PIL.Image.LINEAR,
            'bilinear': PIL.Image.BILINEAR,
            'bicubic': PIL.Image.BICUBIC,
            'lanczos': PIL.Image.LANCZOS,
        }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):

@@ -127,34 +156,47 @@ class PersonalizedBase(Dataset):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == 'RGB':
            image = image.convert('RGB')

        placeholder_string = self.placeholder_token
        if self.coarse_class_text:
            placeholder_string = (
                f'{self.coarse_class_text} {placeholder_string}'
            )

        if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
            text = random.choice(imagenet_dual_templates_small).format(
                placeholder_string, per_img_token_list[i % self.num_images]
            )
        else:
            text = random.choice(imagenet_templates_small).format(
                placeholder_string
            )

        example['caption'] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w, = (
                img.shape[0],
                img.shape[1],
            )
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize(
                (self.size, self.size), resample=self.interpolation
            )

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
        return example

View File

@@ -50,29 +50,55 @@ imagenet_dual_templates_small = [
]

per_img_token_list = [
    'א',
    'ב',
    'ג',
    'ד',
    'ה',
    'ו',
    'ז',
    'ח',
    'ט',
    'י',
    'כ',
    'ל',
    'מ',
    'נ',
    'ס',
    'ע',
    'פ',
    'צ',
    'ק',
    'ר',
    'ש',
    'ת',
]

class PersonalizedBase(Dataset):
    def __init__(
        self,
        data_root,
        size=None,
        repeats=100,
        interpolation='bicubic',
        flip_p=0.5,
        set='train',
        placeholder_token='*',
        per_image_tokens=False,
        center_crop=False,
    ):

        self.data_root = data_root

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
        ]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token

@@ -80,17 +106,20 @@ class PersonalizedBase(Dataset):
        self.center_crop = center_crop

        if per_image_tokens:
            assert self.num_images < len(
                per_img_token_list
            ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == 'train':
            self._length = self.num_images * repeats

        self.size = size
        self.interpolation = {
            'linear': PIL.Image.LINEAR,
            'bilinear': PIL.Image.BILINEAR,
            'bicubic': PIL.Image.BICUBIC,
            'lanczos': PIL.Image.LANCZOS,
        }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):

@@ -100,30 +129,41 @@ class PersonalizedBase(Dataset):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == 'RGB':
            image = image.convert('RGB')

        if self.per_image_tokens and np.random.uniform() < 0.25:
            text = random.choice(imagenet_dual_templates_small).format(
                self.placeholder_token, per_img_token_list[i % self.num_images]
            )
        else:
            text = random.choice(imagenet_templates_small).format(
                self.placeholder_token
            )

        example['caption'] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w, = (
                img.shape[0],
                img.shape[1],
            )
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize(
                (self.size, self.size), resample=self.interpolation
            )

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
        return example

View File

@@ -1,4 +1,4 @@
"""
Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds
             appropriate names for them, and writes prompt metadata

@@ -7,95 +7,104 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds
             prompt for file/directory names.
PromptFormatter -- Utility for converting a Namespace of prompt parameters
             back into a formatted prompt string with command-line switches.
"""
import os
import re
from math import sqrt, floor, ceil
from PIL import Image, PngImagePlugin

# -------------------image generation utils-----
class PngWriter:
    def __init__(self, outdir, prompt=None, batch_size=1):
        self.outdir = outdir
        self.batch_size = batch_size
        self.prompt = prompt
        self.filepath = None
        self.files_written = []
        os.makedirs(outdir, exist_ok=True)

    def write_image(self, image, seed):
        self.filepath = self.unique_filename(
            seed, self.filepath
        )  # will increment name in some sensible way
        try:
            prompt = f'{self.prompt} -S{seed}'
            self.save_image_and_prompt_to_png(image, prompt, self.filepath)
        except IOError as e:
            print(e)
        self.files_written.append([self.filepath, seed])

    def unique_filename(self, seed, previouspath=None):
        revision = 1

        if previouspath is None:
            # sort reverse alphabetically until we find max+1
            dirlist = sorted(os.listdir(self.outdir), reverse=True)
            # find the first filename that matches our pattern or return 000000.0.png
            filename = next(
                (f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
                '0000000.0.png',
            )
            basecount = int(filename.split('.', 1)[0])
            basecount += 1
            if self.batch_size > 1:
                filename = f'{basecount:06}.{seed}.01.png'
            else:
                filename = f'{basecount:06}.{seed}.png'
            return os.path.join(self.outdir, filename)
        else:
            basename = os.path.basename(previouspath)
            x = re.match('^(\d+)\..*\.png', basename)
            if not x:
                return self.unique_filename(seed, previouspath)

            basecount = int(x.groups()[0])
            series = 0
            finished = False
            while not finished:
                series += 1
                filename = f'{basecount:06}.{seed}.png'
                if self.batch_size > 1 or os.path.exists(
                    os.path.join(self.outdir, filename)
                ):
                    filename = f'{basecount:06}.{seed}.{series:02}.png'
                finished = not os.path.exists(
                    os.path.join(self.outdir, filename)
                )
            return os.path.join(self.outdir, filename)

    def save_image_and_prompt_to_png(self, image, prompt, path):
        info = PngImagePlugin.PngInfo()
        info.add_text('Dream', prompt)
        image.save(path, 'PNG', pnginfo=info)

    def make_grid(self, image_list, rows=None, cols=None):
        image_cnt = len(image_list)
        if None in (rows, cols):
            rows = floor(sqrt(image_cnt))  # try to make it square
            cols = ceil(image_cnt / rows)
        width = image_list[0].width
        height = image_list[0].height

        grid_img = Image.new('RGB', (width * cols, height * rows))
        for r in range(0, rows):
            for c in range(0, cols):
                i = r * rows + c
                grid_img.paste(image_list[i], (c * width, r * height))

        return grid_img

class PromptFormatter:
    def __init__(self, t2i, opt):
        self.t2i = t2i
        self.opt = opt

    def normalize_prompt(self):
        """Normalize the prompt and switches"""
        t2i = self.t2i
        opt = self.opt

        switches = list()
        switches.append(f'"{opt.prompt}"')

@@ -114,4 +123,3 @@ class PromptFormatter:
        if t2i.full_precision:
            switches.append('-F')
        return ' '.join(switches)

View File

@@ -1,37 +1,40 @@
"""
Readline helper functions for dream.py (linux and mac only).
"""
import os
import re
import atexit

# ---------------readline utilities---------------------
try:
    import readline
    readline_available = True
except:
    readline_available = False

class Completer:
    def __init__(self, options):
        self.options = sorted(options)
        return

    def complete(self, text, state):
        buffer = readline.get_line_buffer()

        if text.startswith(('-I', '--init_img')):
            return self._path_completions(text, state, ('.png'))

        if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
            return self._path_completions(text, state, ())

        response = None
        if state == 0:
            # This is the first time for this text, so build a match list.
            if text:
                self.matches = [
                    s for s in self.options if s and s.startswith(text)
                ]
            else:
                self.matches = self.options[:]

@@ -43,32 +46,34 @@ class Completer:
            response = None
        return response

    def _path_completions(self, text, state, extensions):
        # get the path so far
        if text.startswith('-I'):
            path = text.replace('-I', '', 1).lstrip()
        elif text.startswith('--init_img='):
            path = text.replace('--init_img=', '', 1).lstrip()
        else:
            path = text

        matches = list()

        path = os.path.expanduser(path)
        if len(path) == 0:
            matches.append(text + './')
        else:
            dir = os.path.dirname(path)
            dir_list = os.listdir(dir)
            for n in dir_list:
                if n.startswith('.') and len(n) > 1:
                    continue
                full_path = os.path.join(dir, n)
                if full_path.startswith(path):
                    if os.path.isdir(full_path):
                        matches.append(
                            os.path.join(os.path.dirname(text), n) + '/'
                        )
                    elif n.endswith(extensions):
                        matches.append(os.path.join(os.path.dirname(text), n))

        try:
            response = matches[state]

@@ -76,19 +81,47 @@ class Completer:
            response = None
        return response

if readline_available:
    readline.set_completer(
        Completer(
            [
                'cd',
                'pwd',
                '--steps',
                '-s',
                '--seed',
                '-S',
                '--iterations',
                '-n',
                '--batch_size',
                '-b',
                '--width',
                '-W',
                '--height',
                '-H',
                '--cfg_scale',
                '-C',
                '--grid',
                '-g',
                '--individual',
                '-i',
                '--init_img',
                '-I',
                '--strength',
                '-f',
                '-v',
                '--variants',
            ]
        ).complete
    )
    readline.set_completer_delims(' ')
    readline.parse_and_bind('tab: complete')

    histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
    try:
        readline.read_history_file(histfile)
        readline.set_history_length(1000)
    except FileNotFoundError:
        pass
    atexit.register(readline.write_history_file, histfile)

View File

@@ -5,32 +5,49 @@ class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f'current step: {n}, recent lr-multiplier: {self.last_lr}'
                )
        if n < self.lr_warm_up_steps:
            lr = (
                self.lr_max - self.lr_start
            ) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (
                self.lr_max_decay_steps - self.lr_warm_up_steps
            )
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                1 + np.cos(t * np.pi)
            )
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)

class LambdaWarmUpCosineScheduler2:

@@ -38,15 +55,30 @@ class LambdaWarmUpCosineScheduler2:
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self,
        warm_up_steps,
        f_min,
        f_max,
        f_start,
        cycle_lengths,
        verbosity_interval=0,
    ):
        assert (
            len(warm_up_steps)
            == len(f_min)
            == len(f_max)
            == len(f_start)
            == len(cycle_lengths)
        )
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):

@@ -60,17 +92,25 @@ class LambdaWarmUpCosineScheduler2:
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f'current step: {n}, recent lr-multiplier: {self.last_f}, '
                    f'current cycle {cycle}'
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (
                self.f_max[cycle] - self.f_start[cycle]
            ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (
                self.f_max[cycle] - self.f_min[cycle]
            ) * (1 + np.cos(t * np.pi))
            self.last_f = f
            return f

@@ -79,20 +119,25 @@ class LambdaWarmUpCosineScheduler2:
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f'current step: {n}, recent lr-multiplier: {self.last_f}, '
                    f'current cycle {cycle}'
                )

        if n < self.lr_warm_up_steps[cycle]:
            f = (
                self.f_max[cycle] - self.f_start[cycle]
            ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
                self.cycle_lengths[cycle] - n
            ) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f

View File

@@ -6,29 +6,32 @@ from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

from ldm.util import instantiate_from_config

class VQModel(pl.LightningModule):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        n_embed,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key='image',
        colorize_nlabels=None,
        monitor=None,
        batch_resize_range=None,
        scheduler_config=None,
        lr_g_factor=1.0,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        use_ema=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed

@@ -36,24 +39,34 @@ class VQModel(pl.LightningModule):
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(
            n_embed,
            embed_dim,
            beta=0.25,
            remap=remap,
            sane_index_shape=sane_index_shape,
        )
        self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(
            embed_dim, ddconfig['z_channels'], 1
        )
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer(
                'colorize', torch.randn(3, colorize_nlabels, 1, 1)
            )
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(
                f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
            )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

@@ -66,28 +79,30 @@ class VQModel(pl.LightningModule):
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f'{context}: Switched to EMA weights')
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f'{context}: Restored training weights')

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
        )
        if len(missing) > 0:
            print(f'Missing Keys: {missing}')
            print(f'Unexpected Keys: {unexpected}')

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:

@@ -115,7 +130,7 @@ class VQModel(pl.LightningModule):
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind

@@ -125,7 +140,11 @@ class VQModel(pl.LightningModule):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = (
            x.permute(0, 3, 1, 2)
            .to(memory_format=torch.contiguous_format)
            .float()
        )
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]

@@ -133,9 +152,11 @@ class VQModel(pl.LightningModule):
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(
                    np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]: if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic") x = F.interpolate(x, size=new_resize, mode='bicubic')
x = x.detach() x = x.detach()
return x return x
@ -147,81 +168,139 @@ class VQModel(pl.LightningModule):
if optimizer_idx == 0: if optimizer_idx == 0:
# autoencode # autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, aeloss, log_dict_ae = self.loss(
last_layer=self.get_last_layer(), split="train", qloss,
predicted_indices=ind) x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
predicted_indices=ind,
)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return aeloss return aeloss
if optimizer_idx == 1: if optimizer_idx == 1:
# discriminator # discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, discloss, log_dict_disc = self.loss(
last_layer=self.get_last_layer(), split="train") qloss,
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return discloss return discloss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx) log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope(): with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") log_dict_ema = self._validation_step(
batch, batch_idx, suffix='_ema'
)
return log_dict return log_dict
def _validation_step(self, batch, batch_idx, suffix=""): def _validation_step(self, batch, batch_idx, suffix=''):
x = self.get_input(batch, self.image_key) x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True) xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, aeloss, log_dict_ae = self.loss(
self.global_step, qloss,
last_layer=self.get_last_layer(), x,
split="val"+suffix, xrec,
predicted_indices=ind 0,
) self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind,
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, discloss, log_dict_disc = self.loss(
self.global_step, qloss,
last_layer=self.get_last_layer(), x,
split="val"+suffix, xrec,
predicted_indices=ind 1,
) self.global_step,
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] last_layer=self.get_last_layer(),
self.log(f"val{suffix}/rec_loss", rec_loss, split='val' + suffix,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) predicted_indices=ind,
self.log(f"val{suffix}/aeloss", aeloss, )
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
self.log(
f'val{suffix}/rec_loss',
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log(
f'val{suffix}/aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
if version.parse(pl.__version__) >= version.parse('1.4.0'): if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"] del log_dict_ae[f'val{suffix}/rec_loss']
self.log_dict(log_dict_ae) self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc) self.log_dict(log_dict_disc)
return self.log_dict return self.log_dict
def configure_optimizers(self): def configure_optimizers(self):
lr_d = self.learning_rate lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate lr_g = self.lr_g_factor * self.learning_rate
print("lr_d", lr_d) print('lr_d', lr_d)
print("lr_g", lr_g) print('lr_g', lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ opt_ae = torch.optim.Adam(
list(self.decoder.parameters())+ list(self.encoder.parameters())
list(self.quantize.parameters())+ + list(self.decoder.parameters())
list(self.quant_conv.parameters())+ + list(self.quantize.parameters())
list(self.post_quant_conv.parameters()), + list(self.quant_conv.parameters())
lr=lr_g, betas=(0.5, 0.9)) + list(self.post_quant_conv.parameters()),
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr_g,
lr=lr_d, betas=(0.5, 0.9)) betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None: if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config) scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...") print('Setting up LambdaLR scheduler...')
scheduler = [ scheduler = [
{ {
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 'scheduler': LambdaLR(
opt_ae, lr_lambda=scheduler.schedule
),
'interval': 'step', 'interval': 'step',
'frequency': 1 'frequency': 1,
}, },
{ {
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 'scheduler': LambdaLR(
opt_disc, lr_lambda=scheduler.schedule
),
'interval': 'step', 'interval': 'step',
'frequency': 1 'frequency': 1,
}, },
] ]
return [opt_ae, opt_disc], scheduler return [opt_ae, opt_disc], scheduler
@ -235,7 +314,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key) x = self.get_input(batch, self.image_key)
x = x.to(self.device) x = x.to(self.device)
if only_inputs: if only_inputs:
log["inputs"] = x log['inputs'] = x
return log return log
xrec, _ = self(x) xrec, _ = self(x)
if x.shape[1] > 3: if x.shape[1] > 3:
@ -243,21 +322,24 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3 assert xrec.shape[1] > 3
x = self.to_rgb(x) x = self.to_rgb(x)
xrec = self.to_rgb(xrec) xrec = self.to_rgb(xrec)
log["inputs"] = x log['inputs'] = x
log["reconstructions"] = xrec log['reconstructions'] = xrec
if plot_ema: if plot_ema:
with self.ema_scope(): with self.ema_scope():
xrec_ema, _ = self(x) xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) if x.shape[1] > 3:
log["reconstructions_ema"] = xrec_ema xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
return log return log
def to_rgb(self, x): def to_rgb(self, x):
assert self.image_key == "segmentation" assert self.image_key == 'segmentation'
if not hasattr(self, "colorize"): if not hasattr(self, 'colorize'):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize) x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1. x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x return x
@ -283,43 +365,50 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule): class AutoencoderKL(pl.LightningModule):
def __init__(self, def __init__(
ddconfig, self,
lossconfig, ddconfig,
embed_dim, lossconfig,
ckpt_path=None, embed_dim,
ignore_keys=[], ckpt_path=None,
image_key="image", ignore_keys=[],
colorize_nlabels=None, image_key='image',
monitor=None, colorize_nlabels=None,
): monitor=None,
):
super().__init__() super().__init__()
self.image_key = image_key self.image_key = image_key
self.encoder = Encoder(**ddconfig) self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig) self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig) self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"] assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) self.quant_conv = torch.nn.Conv2d(
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 2 * ddconfig['z_channels'], 2 * embed_dim, 1
)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
self.embed_dim = embed_dim self.embed_dim = embed_dim
if colorize_nlabels is not None: if colorize_nlabels is not None:
assert type(colorize_nlabels)==int assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()): def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"] sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print('Deleting key {} from state_dict.'.format(k))
del sd[k] del sd[k]
self.load_state_dict(sd, strict=False) self.load_state_dict(sd, strict=False)
print(f"Restored from {path}") print(f'Restored from {path}')
def encode(self, x): def encode(self, x):
h = self.encoder(x) h = self.encoder(x)
@ -345,7 +434,11 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k] x = batch[k]
if len(x.shape) == 3: if len(x.shape) == 3:
x = x[..., None] x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
return x return x
def training_step(self, batch, batch_idx, optimizer_idx): def training_step(self, batch, batch_idx, optimizer_idx):
@ -354,44 +447,102 @@ class AutoencoderKL(pl.LightningModule):
if optimizer_idx == 0: if optimizer_idx == 0:
# train encoder+decoder+logvar # train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, aeloss, log_dict_ae = self.loss(
last_layer=self.get_last_layer(), split="train") inputs,
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) reconstructions,
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return aeloss return aeloss
if optimizer_idx == 1: if optimizer_idx == 1:
# train the discriminator # train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, discloss, log_dict_disc = self.loss(
last_layer=self.get_last_layer(), split="train") inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log(
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return discloss return discloss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key) inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs) reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, aeloss, log_dict_ae = self.loss(
last_layer=self.get_last_layer(), split="val") inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, discloss, log_dict_disc = self.loss(
last_layer=self.get_last_layer(), split="val") inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log_dict(log_dict_ae) self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc) self.log_dict(log_dict_disc)
return self.log_dict return self.log_dict
def configure_optimizers(self): def configure_optimizers(self):
lr = self.learning_rate lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ opt_ae = torch.optim.Adam(
list(self.decoder.parameters())+ list(self.encoder.parameters())
list(self.quant_conv.parameters())+ + list(self.decoder.parameters())
list(self.post_quant_conv.parameters()), + list(self.quant_conv.parameters())
lr=lr, betas=(0.5, 0.9)) + list(self.post_quant_conv.parameters()),
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr,
lr=lr, betas=(0.5, 0.9)) betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], [] return [opt_ae, opt_disc], []
def get_last_layer(self): def get_last_layer(self):
@ -409,17 +560,19 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3 assert xrec.shape[1] > 3
x = self.to_rgb(x) x = self.to_rgb(x)
xrec = self.to_rgb(xrec) xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample())) log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec log['reconstructions'] = xrec
log["inputs"] = x log['inputs'] = x
return log return log
def to_rgb(self, x): def to_rgb(self, x):
assert self.image_key == "segmentation" assert self.image_key == 'segmentation'
if not hasattr(self, "colorize"): if not hasattr(self, 'colorize'):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize) x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1. x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x return x
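The get_input hunks above reshape image batches before they reach the encoder; a minimal sketch of that handling, with torch as the only dependency (the helper name is illustrative): batches stored channel-last as (B, H, W, C) are permuted to channel-first (B, C, H, W), made contiguous, and cast to float.

import torch

def to_nchw(x: torch.Tensor) -> torch.Tensor:
    # single-channel data arriving as (B, H, W) gets a trailing channel axis
    if x.ndim == 3:
        x = x[..., None]
    # channel-last -> channel-first, contiguous float tensor for the encoder
    return (
        x.permute(0, 3, 1, 2)
        .to(memory_format=torch.contiguous_format)
        .float()
    )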
@ -10,13 +10,13 @@ from einops import rearrange
from glob import glob from glob import glob
from natsort import natsorted from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel from ldm.modules.diffusionmodules.openaimodel import (
EncoderUNetModel,
UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = { __models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}
def disabled_train(self, mode=True): def disabled_train(self, mode=True):
@ -26,37 +26,49 @@ def disabled_train(self, mode=True):
class NoisyLatentImageClassifier(pl.LightningModule): class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(
def __init__(self, self,
diffusion_path, diffusion_path,
num_classes, num_classes,
ckpt_path=None, ckpt_path=None,
pool='attention', pool='attention',
label_key=None, label_key=None,
diffusion_ckpt_path=None, diffusion_ckpt_path=None,
scheduler_config=None, scheduler_config=None,
weight_decay=1.e-2, weight_decay=1.0e-2,
log_steps=10, log_steps=10,
monitor='val/loss', monitor='val/loss',
*args, *args,
**kwargs): **kwargs,
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.num_classes = num_classes self.num_classes = num_classes
# get latest config of diffusion model # get latest config of diffusion model
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] diffusion_config = natsorted(
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion() self.load_diffusion()
self.monitor = monitor self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 self.numd = (
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
)
self.log_time_interval = (
self.diffusion_model.num_timesteps // log_steps
)
self.log_steps = log_steps self.log_steps = log_steps
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ self.label_key = (
label_key
if not hasattr(self.diffusion_model, 'cond_stage_key')
else self.diffusion_model.cond_stage_key else self.diffusion_model.cond_stage_key
)
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' assert (
self.label_key is not None
), 'label_key neither in diffusion model nor in model.params'
if self.label_key not in __models__: if self.label_key not in __models__:
raise NotImplementedError() raise NotImplementedError()
@ -68,22 +80,27 @@ class NoisyLatentImageClassifier(pl.LightningModule):
self.weight_decay = weight_decay self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu") sd = torch.load(path, map_location='cpu')
if "state_dict" in list(sd.keys()): if 'state_dict' in list(sd.keys()):
sd = sd["state_dict"] sd = sd['state_dict']
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print('Deleting key {} from state_dict.'.format(k))
del sd[k] del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( missing, unexpected = (
sd, strict=False) self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0: if len(missing) > 0:
print(f"Missing Keys: {missing}") print(f'Missing Keys: {missing}')
if len(unexpected) > 0: if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}") print(f'Unexpected Keys: {unexpected}')
def load_diffusion(self): def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config) model = instantiate_from_config(self.diffusion_config)
@ -93,17 +110,25 @@ class NoisyLatentImageClassifier(pl.LightningModule):
param.requires_grad = False param.requires_grad = False
def load_classifier(self, ckpt_path, pool): def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params) model_config = deepcopy(
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels self.diffusion_config.params.unet_config.params
)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes model_config.out_channels = self.num_classes
if self.label_key == 'class_label': if self.label_key == 'class_label':
model_config.pool = pool model_config.pool = pool
self.model = __models__[self.label_key](**model_config) self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None: if ckpt_path is not None:
print('#####################################################################') print(
'#####################################################################'
)
print(f'load from ckpt "{ckpt_path}"') print(f'load from ckpt "{ckpt_path}"')
print('#####################################################################') print(
'#####################################################################'
)
self.init_from_ckpt(ckpt_path) self.init_from_ckpt(ckpt_path)
@torch.no_grad() @torch.no_grad()
@ -111,11 +136,19 @@ class NoisyLatentImageClassifier(pl.LightningModule):
noise = default(noise, lambda: torch.randn_like(x)) noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise: if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(
x.shape[0], t + 1
)
)
# todo: make sure t+1 is correct here # todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, return self.diffusion_model.q_sample(
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs): def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t) return self.model(x_noisy, t)
@ -141,17 +174,21 @@ class NoisyLatentImageClassifier(pl.LightningModule):
targets = rearrange(targets, 'b h w c -> b c h w') targets = rearrange(targets, 'b h w c -> b c h w')
for down in range(self.numd): for down in range(self.numd):
h, w = targets.shape[-2:] h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') targets = F.interpolate(
targets, size=(h // 2, w // 2), mode='nearest'
)
# targets = rearrange(targets,'b c h w -> b h w c') # targets = rearrange(targets,'b c h w -> b h w c')
return targets return targets
def compute_top_k(self, logits, labels, k, reduction="mean"): def compute_top_k(self, logits, labels, k, reduction='mean'):
_, top_ks = torch.topk(logits, k, dim=1) _, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean": if reduction == 'mean':
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() return (
elif reduction == "none": (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
)
elif reduction == 'none':
return (top_ks == labels[:, None]).float().sum(dim=-1) return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self): def on_train_epoch_start(self):
@ -162,29 +199,59 @@ class NoisyLatentImageClassifier(pl.LightningModule):
def write_logs(self, loss, logits, targets): def write_logs(self, loss, logits, targets):
log_prefix = 'train' if self.training else 'val' log_prefix = 'train' if self.training else 'val'
log = {} log = {}
log[f"{log_prefix}/loss"] = loss.mean() log[f'{log_prefix}/loss'] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k( log[f'{log_prefix}/acc@1'] = self.compute_top_k(
logits, targets, k=1, reduction="mean" logits, targets, k=1, reduction='mean'
) )
log[f"{log_prefix}/acc@5"] = self.compute_top_k( log[f'{log_prefix}/acc@5'] = self.compute_top_k(
logits, targets, k=5, reduction="mean" logits, targets, k=5, reduction='mean'
) )
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) self.log_dict(
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) log,
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) prog_bar=False,
logger=True,
on_step=self.training,
on_epoch=True,
)
self.log(
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
)
self.log(
'global_step',
self.global_step,
logger=False,
on_epoch=False,
prog_bar=True,
)
lr = self.optimizers().param_groups[0]['lr'] lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) self.log(
'lr_abs',
lr,
on_step=True,
logger=True,
on_epoch=False,
prog_bar=True,
)
def shared_step(self, batch, t=None): def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch) targets = self.get_conditioning(batch)
if targets.dim() == 4: if targets.dim() == 4:
targets = targets.argmax(dim=1) targets = targets.argmax(dim=1)
if t is None: if t is None:
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() t = torch.randint(
0,
self.diffusion_model.num_timesteps,
(x.shape[0],),
device=self.device,
).long()
else: else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() t = torch.full(
size=(x.shape[0],), fill_value=t, device=self.device
).long()
x_noisy = self.get_x_noisy(x, t) x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t) logits = self(x_noisy, t)
@ -200,8 +267,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
return loss return loss
def reset_noise_accs(self): def reset_noise_accs(self):
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in self.noisy_acc = {
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} t: {'acc@1': [], 'acc@5': []}
for t in range(
0,
self.diffusion_model.num_timesteps,
self.diffusion_model.log_every_t,
)
}
def on_validation_start(self): def on_validation_start(self):
self.reset_noise_accs() self.reset_noise_accs()
@ -212,24 +285,35 @@ class NoisyLatentImageClassifier(pl.LightningModule):
for t in self.noisy_acc: for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t) _, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) self.noisy_acc[t]['acc@1'].append(
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) self.compute_top_k(logits, targets, k=1, reduction='mean')
)
self.noisy_acc[t]['acc@5'].append(
self.compute_top_k(logits, targets, k=5, reduction='mean')
)
return loss return loss
def configure_optimizers(self): def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler: if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config) scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...") print('Setting up LambdaLR scheduler...')
scheduler = [ scheduler = [
{ {
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 'scheduler': LambdaLR(
optimizer, lr_lambda=scheduler.schedule
),
'interval': 'step', 'interval': 'step',
'frequency': 1 'frequency': 1,
}] }
]
return [optimizer], scheduler return [optimizer], scheduler
return optimizer return optimizer
@ -243,7 +327,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
y = self.get_conditioning(batch) y = self.get_conditioning(batch)
if self.label_key == 'class_label': if self.label_key == 'class_label':
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
log['labels'] = y log['labels'] = y
if ismap(y): if ismap(y):
@ -256,10 +340,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
log[f'inputs@t{current_time}'] = x_noisy log[f'inputs@t{current_time}'] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) pred = F.one_hot(
logits.argmax(dim=1), num_classes=self.num_classes
)
pred = rearrange(pred, 'b h w c -> b c h w') pred = rearrange(pred, 'b h w c -> b c h w')
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
pred
)
for key in log: for key in log:
log[key] = log[key][:N] log[key] = log[key][:N]
@ -5,12 +5,16 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ from ldm.modules.diffusionmodules.util import (
extract_into_tensor make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object): class DDIMSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs): def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
@ -23,70 +27,122 @@ class DDIMSampler(object):
attr = attr.to(torch.device(self.device)) attr = attr.to(torch.device(self.device))
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, self,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' assert (
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) )
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters # ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), (
ddim_timesteps=self.ddim_timesteps, ddim_sigmas,
eta=ddim_eta,verbose=verbose) ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( (1 - self.alphas_cumprod_prev)
1 - self.alphas_cumprod / self.alphas_cumprod_prev)) / (1 - self.alphas_cumprod)
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(
S, self,
batch_size, S,
shape, batch_size,
conditioning=None, shape,
callback=None, conditioning=None,
normals_sequence=None, callback=None,
img_callback=None, normals_sequence=None,
quantize_x0=False, img_callback=None,
eta=0., quantize_x0=False,
mask=None, eta=0.0,
x0=None, mask=None,
temperature=1., x0=None,
noise_dropout=0., temperature=1.0,
score_corrector=None, noise_dropout=0.0,
corrector_kwargs=None, score_corrector=None,
verbose=True, corrector_kwargs=None,
x_T=None, verbose=True,
log_every_t=100, x_T=None,
unconditional_guidance_scale=1., log_every_t=100,
unconditional_conditioning=None, unconditional_guidance_scale=1.0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... unconditional_conditioning=None,
**kwargs # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
): **kwargs,
):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0] cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else: else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling # sampling
@ -94,30 +150,47 @@ class DDIMSampler(object):
size = (batch_size, C, H, W) size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}') print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size, samples, intermediates = self.ddim_sampling(
callback=callback, conditioning,
img_callback=img_callback, size,
quantize_denoised=quantize_x0, callback=callback,
mask=mask, x0=x0, img_callback=img_callback,
ddim_use_original_steps=False, quantize_denoised=quantize_x0,
noise_dropout=noise_dropout, mask=mask,
temperature=temperature, x0=x0,
score_corrector=score_corrector, ddim_use_original_steps=False,
corrector_kwargs=corrector_kwargs, noise_dropout=noise_dropout,
x_T=x_T, temperature=temperature,
log_every_t=log_every_t, score_corrector=score_corrector,
unconditional_guidance_scale=unconditional_guidance_scale, corrector_kwargs=corrector_kwargs,
unconditional_conditioning=unconditional_conditioning, x_T=x_T,
) log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates return samples, intermediates
@torch.no_grad() @torch.no_grad()
def ddim_sampling(self, cond, shape, def ddim_sampling(
x_T=None, ddim_use_original_steps=False, self,
callback=None, timesteps=None, quantize_denoised=False, cond,
mask=None, x0=None, img_callback=None, log_every_t=100, shape,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,): ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
@ -126,17 +199,38 @@ class DDIMSampler(object):
img = x_T img = x_T
if timesteps is None: if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps: elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end] timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]} intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) time_range = (
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] reversed(range(0, timesteps))
print(f"Running DDIM Sampling with {total_steps} timesteps") if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True) iterator = tqdm(
time_range,
desc='DDIM Sampler',
total=total_steps,
dynamic_ncols=True,
)
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
@ -144,18 +238,30 @@ class DDIMSampler(object):
if mask is not None: if mask is not None:
assert x0 is not None assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img_orig = self.model.q_sample(
img = img_orig * mask + (1. - mask) * img x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, outs = self.p_sample_ddim(
quantize_denoised=quantize_denoised, temperature=temperature, img,
noise_dropout=noise_dropout, score_corrector=score_corrector, cond,
corrector_kwargs=corrector_kwargs, ts,
unconditional_guidance_scale=unconditional_guidance_scale, index=index,
unconditional_conditioning=unconditional_conditioning) use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs img, pred_x0 = outs
if callback: callback(i) if callback:
if img_callback: img_callback(pred_x0, i) callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates['x_inter'].append(img)
@ -164,42 +270,82 @@ class DDIMSampler(object):
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_ddim(
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, self,
unconditional_guidance_scale=1., unconditional_conditioning=None): x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c]) c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == "eps" assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas = (
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev self.model.alphas_cumprod
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas if use_original_steps
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0 # current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = (
if noise_dropout > 0.: sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0 return x_prev, pred_x0
@ -217,26 +363,51 @@ class DDIMSampler(object):
if noise is None: if noise is None:
noise = torch.randn_like(x0) noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + return (
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
@torch.no_grad() @torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, def decode(
use_original_steps=False): self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start] timesteps = timesteps[:t_start]
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps") print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps) iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent x_dec = x_latent
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) ts = torch.full(
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, (x_latent.shape[0],),
unconditional_guidance_scale=unconditional_guidance_scale, step,
unconditional_conditioning=unconditional_conditioning) device=x_latent.device,
dtype=torch.long,
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return x_dec return x_dec
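p_sample_ddim above combines an unconditional and a conditional noise prediction whenever a guidance scale other than 1.0 is used. A minimal sketch of that classifier-free guidance step; model_fn stands in for self.model.apply_model and all names are illustrative:

import torch

def guided_eps(model_fn, x, t, cond, uncond, guidance_scale):
    # run the conditional and unconditional cases in one batched call
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([uncond, cond])   # unconditional first, matching the code above
    e_uncond, e_cond = model_fn(x_in, t_in, c_in).chunk(2)
    # blend the two noise estimates with the guidance scale
    return e_uncond + guidance_scale * (e_cond - e_uncond)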
File diff suppressed because it is too large
@ -1,8 +1,9 @@
'''wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers''' """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K import k_diffusion as K
import torch import torch
import torch.nn as nn import torch.nn as nn
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
@ -15,8 +16,9 @@ class CFGDenoiser(nn.Module):
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
class KSampler(object): class KSampler(object):
def __init__(self, model, schedule="lms", device="cuda", **kwargs): def __init__(self, model, schedule='lms', device='cuda', **kwargs):
super().__init__() super().__init__()
self.model = K.external.CompVisDenoiser(model) self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule self.schedule = schedule
@ -26,44 +28,57 @@ class KSampler(object):
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2) sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond]) cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
# most of these arguments are ignored and are only present for compatibility with # most of these arguments are ignored and are only present for compatibility with
# other samples # other samples
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(
S, self,
batch_size, S,
shape, batch_size,
conditioning=None, shape,
callback=None, conditioning=None,
normals_sequence=None, callback=None,
img_callback=None, normals_sequence=None,
quantize_x0=False, img_callback=None,
eta=0., quantize_x0=False,
mask=None, eta=0.0,
x0=None, mask=None,
temperature=1., x0=None,
noise_dropout=0., temperature=1.0,
score_corrector=None, noise_dropout=0.0,
corrector_kwargs=None, score_corrector=None,
verbose=True, corrector_kwargs=None,
x_T=None, verbose=True,
log_every_t=100, x_T=None,
unconditional_guidance_scale=1., log_every_t=100,
unconditional_conditioning=None, unconditional_guidance_scale=1.0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... unconditional_conditioning=None,
**kwargs # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
): **kwargs,
):
sigmas = self.model.get_sigmas(S) sigmas = self.model.get_sigmas(S)
if x_T: if x_T:
x = x_T x = x_T
else: else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw x = (
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model) model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale} extra_args = {
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args), 'cond': conditioning,
None) 'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args
),
None,
)
@ -5,11 +5,15 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class PLMSSampler(object): class PLMSSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs): def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
@ -23,103 +27,172 @@ class PLMSSampler(object):
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
if ddim_eta != 0: if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS') raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, self.ddim_timesteps = make_ddim_timesteps(
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' assert (
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) )
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) self.register_buffer(
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters # ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), (
ddim_timesteps=self.ddim_timesteps, ddim_sigmas,
eta=ddim_eta,verbose=verbose) ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( (1 - self.alphas_cumprod_prev)
1 - self.alphas_cumprod / self.alphas_cumprod_prev)) / (1 - self.alphas_cumprod)
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(
S, self,
batch_size, S,
shape, batch_size,
conditioning=None, shape,
callback=None, conditioning=None,
normals_sequence=None, callback=None,
img_callback=None, normals_sequence=None,
quantize_x0=False, img_callback=None,
eta=0., quantize_x0=False,
mask=None, eta=0.0,
x0=None, mask=None,
temperature=1., x0=None,
noise_dropout=0., temperature=1.0,
score_corrector=None, noise_dropout=0.0,
corrector_kwargs=None, score_corrector=None,
verbose=True, corrector_kwargs=None,
x_T=None, verbose=True,
log_every_t=100, x_T=None,
unconditional_guidance_scale=1., log_every_t=100,
unconditional_conditioning=None, unconditional_guidance_scale=1.0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... unconditional_conditioning=None,
**kwargs # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
): **kwargs,
):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0] cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else: else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling # sampling
C, H, W = shape C, H, W = shape
size = (batch_size, C, H, W) size = (batch_size, C, H, W)
# print(f'Data shape for PLMS sampling is {size}') # print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size, samples, intermediates = self.plms_sampling(
callback=callback, conditioning,
img_callback=img_callback, size,
quantize_denoised=quantize_x0, callback=callback,
mask=mask, x0=x0, img_callback=img_callback,
ddim_use_original_steps=False, quantize_denoised=quantize_x0,
noise_dropout=noise_dropout, mask=mask,
temperature=temperature, x0=x0,
score_corrector=score_corrector, ddim_use_original_steps=False,
corrector_kwargs=corrector_kwargs, noise_dropout=noise_dropout,
x_T=x_T, temperature=temperature,
log_every_t=log_every_t, score_corrector=score_corrector,
unconditional_guidance_scale=unconditional_guidance_scale, corrector_kwargs=corrector_kwargs,
unconditional_conditioning=unconditional_conditioning, x_T=x_T,
) log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates return samples, intermediates
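A hedged usage sketch for the `sample` entry point above. `model` and `get_learned_conditioning` are assumed to come from the surrounding latent-diffusion code, and the prompt, guidance scale and latent shape are illustrative only:

    # illustrative driver code, not part of this diff
    sampler = PLMSSampler(model)
    c = model.get_learned_conditioning(['a photograph of an astronaut'])
    uc = model.get_learned_conditioning([''])
    samples, _ = sampler.sample(
        S=50,                          # number of PLMS steps
        batch_size=1,
        shape=(4, 64, 64),             # latent C, H, W (e.g. 512x512 image / 8)
        conditioning=c,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uc,
        eta=0.0,                       # must be 0 for PLMS
    )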
@torch.no_grad() @torch.no_grad()
def plms_sampling(self, cond, shape, def plms_sampling(
x_T=None, ddim_use_original_steps=False, self,
callback=None, timesteps=None, quantize_denoised=False, cond,
mask=None, x0=None, img_callback=None, log_every_t=100, shape,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,): ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
@ -128,42 +201,81 @@ class PLMSSampler(object):
img = x_T img = x_T
if timesteps is None: if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps: elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end] timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]} intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) time_range = (
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] list(reversed(range(0, timesteps)))
# print(f"Running PLMS Sampling with {total_steps} timesteps") if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
# print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True) iterator = tqdm(
time_range,
desc='PLMS Sampler',
total=total_steps,
dynamic_ncols=True,
)
old_eps = [] old_eps = []
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long) ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None: if mask is not None:
assert x0 is not None assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img_orig = self.model.q_sample(
img = img_orig * mask + (1. - mask) * img x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, outs = self.p_sample_plms(
quantize_denoised=quantize_denoised, temperature=temperature, img,
noise_dropout=noise_dropout, score_corrector=score_corrector, cond,
corrector_kwargs=corrector_kwargs, ts,
unconditional_guidance_scale=unconditional_guidance_scale, index=index,
unconditional_conditioning=unconditional_conditioning, use_original_steps=ddim_use_original_steps,
old_eps=old_eps, t_next=ts_next) quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs img, pred_x0, e_t = outs
old_eps.append(e_t) old_eps.append(e_t)
if len(old_eps) >= 4: if len(old_eps) >= 4:
old_eps.pop(0) old_eps.pop(0)
if callback: callback(i) if callback:
if img_callback: img_callback(pred_x0, i) callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates['x_inter'].append(img)
@ -172,47 +284,95 @@ class PLMSSampler(object):
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_plms(
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, self,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
def get_model_output(x, t): def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c]) c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t_uncond, e_t = self.model.apply_model(
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == "eps" assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas = (
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev self.model.alphas_cumprod
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas if use_original_steps
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index): def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) a_prev = torch.full(
(b, 1, 1, 1), alphas_prev[index], device=device
)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0 # current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = (
if noise_dropout > 0.: sigma_t
* noise_like(x.shape, device, repeat_noise)
* temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0 return x_prev, pred_x0
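In equation form, `get_x_prev_and_pred_x0` implements the DDIM-style update, with `a_t`/`a_prev` the cumulative alphas at the current and previous steps and `e_t` the noise prediction:

    \hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,e_t}{\sqrt{\bar\alpha_t}},
    \qquad
    x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0
              + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,e_t
              + \sigma_t\,z, \quad z \sim \mathcal{N}(0, I)

The noise term is additionally scaled by `temperature` (1.0 by default), and for PLMS `ddim_eta` is forced to 0, so `sigma_t` vanishes and the step is deterministic.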
@ -231,7 +391,12 @@ class PLMSSampler(object):
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3: elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 e_t_prime = (
55 * e_t
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
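For reference, the combinations used here are the standard second-, third- and fourth-order Adams–Bashforth formulas applied to the current and previous noise predictions (the lower-order ones are used until enough history has accumulated):

    e_t' = (3 e_t - e_{t-1}) / 2
    e_t' = (23 e_t - 16 e_{t-1} + 5 e_{t-2}) / 12
    e_t' = (55 e_t - 59 e_{t-1} + 37 e_{t-2} - 9 e_{t-3}) / 24

The combined prediction e_t' is then fed to get_x_prev_and_pred_x0 exactly as a single-step prediction would be.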

View File

@ -13,7 +13,7 @@ def exists(val):
def uniq(arr): def uniq(arr):
return{el: True for el in arr}.keys() return {el: True for el in arr}.keys()
def default(val, d): def default(val, d):
@ -45,19 +45,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = (
nn.Linear(dim, inner_dim), nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
nn.GELU() if not glu
) if not glu else GEGLU(dim, inner_dim) else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
) )
def forward(self, x): def forward(self, x):
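The `GEGLU` branch selected when `glu=True` is, in the variant commonly used with this codebase, a gated GELU projection. A self-contained sketch (the exact class defined earlier in this file may differ in detail):

    import torch.nn as nn
    import torch.nn.functional as F

    class GEGLUSketch(nn.Module):
        def __init__(self, dim_in, dim_out):
            super().__init__()
            # project to twice the width, then split into value and gate
            self.proj = nn.Linear(dim_in, dim_out * 2)

        def forward(self, x):
            x, gate = self.proj(x).chunk(2, dim=-1)
            return x * F.gelu(gate)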
@ -74,7 +73,9 @@ def zero_module(module):
def Normalize(in_channels): def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module): class LinearAttention(nn.Module):
@ -82,17 +83,28 @@ class LinearAttention(nn.Module):
super().__init__() super().__init__()
self.heads = heads self.heads = heads
hidden_dim = dim_head * heads hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1) self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) q, k, v = rearrange(
k = k.softmax(dim=-1) qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3,
)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v) context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q) out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w,
)
return self.to_out(out) return self.to_out(out)
@ -102,26 +114,18 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
kernel_size=1, )
stride=1, self.k = torch.nn.Conv2d(
padding=0) in_channels, in_channels, kernel_size=1, stride=1, padding=0
self.k = torch.nn.Conv2d(in_channels, )
in_channels, self.v = torch.nn.Conv2d(
kernel_size=1, in_channels, in_channels, kernel_size=1, stride=1, padding=0
stride=1, )
padding=0) self.proj_out = torch.nn.Conv2d(
self.v = torch.nn.Conv2d(in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels, )
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@ -131,12 +135,12 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_) v = self.v(h_)
# compute attention # compute attention
b,c,h,w = q.shape b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c') q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)') k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k) w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5)) w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2) w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values # attend to values
@ -146,16 +150,18 @@ class SpatialSelfAttention(nn.Module):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
return x+h_ return x + h_
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5 self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@ -163,8 +169,7 @@ class CrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
nn.Dropout(dropout)
) )
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
@ -175,7 +180,9 @@ class CrossAttention(nn.Module):
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)
)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@ -194,19 +201,37 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
):
super().__init__() super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, self.attn2 = CrossAttention(
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim) self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint self.checkpoint = checkpoint
def forward(self, x, context=None): def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None): def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x x = self.attn1(self.norm1(x)) + x
@ -223,29 +248,43 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action. Then apply standard transformer action.
Finally, reshape to image Finally, reshape to image
""" """
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None): def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, self.proj_in = nn.Conv2d(
inner_dim, in_channels, inner_dim, kernel_size=1, stride=1, padding=0
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
) )
self.proj_out = zero_module(nn.Conv2d(inner_dim, self.transformer_blocks = nn.ModuleList(
in_channels, [
kernel_size=1, BasicTransformerBlock(
stride=1, inner_dim,
padding=0)) n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
)
)
def forward(self, x, context=None): def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention
@ -258,4 +297,4 @@ class SpatialTransformer(nn.Module):
x = block(x, context=context) x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x) x = self.proj_out(x)
return x + x_in return x + x_in
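For orientation, the forward pass whose tail appears above follows the usual shape for this block: normalize, project to a token sequence, run the transformer blocks with optional context, project back and add the residual. A hedged sketch of the whole method (it relies on the modules built in `__init__` above and on `rearrange` from einops; the head is inferred, only the tail is shown in the hunk):

    def forward(self, x, context=None):
        # x: (b, c, h, w); context: optional (b, seq, context_dim) for cross-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c')        # flatten to tokens
        for block in self.transformer_blocks:
            x = block(x, context=context)               # self-attn if context is None
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in                                 # residual connection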

File diff suppressed because it is too large

View File

@ -24,6 +24,7 @@ from ldm.modules.attention import SpatialTransformer
def convert_module_to_f16(x): def convert_module_to_f16(x):
pass pass
def convert_module_to_f32(x): def convert_module_to_f32(x):
pass pass
@ -42,7 +43,9 @@ class AttentionPool2d(nn.Module):
output_dim: int = None, output_dim: int = None,
): ):
super().__init__() super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels self.num_heads = embed_dim // num_heads_channels
@ -97,37 +100,45 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
self.dims = dims self.dims = dims
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
if self.dims == 3: if self.dims == 3:
x = F.interpolate( x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'
) )
else: else:
x = F.interpolate(x, scale_factor=2, mode="nearest") x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv: if self.use_conv:
x = self.conv(x) x = self.conv(x)
return x return x
class TransposedUpsample(nn.Module): class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding' """Learned 2x upsampling without padding"""
def __init__(self, channels, out_channels=None, ks=5): def __init__(self, channels, out_channels=None, ks=5):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2
)
def forward(self,x): def forward(self, x):
return self.up(x) return self.up(x)
@ -140,7 +151,9 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -149,7 +162,12 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
if use_conv: if use_conv:
self.op = conv_nd( self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
) )
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
@ -219,7 +237,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
linear( linear(
emb_channels, emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, 2 * self.out_channels
if use_scale_shift_norm
else self.out_channels,
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
@ -227,7 +247,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1
)
), ),
) )
@ -238,7 +260,9 @@ class ResBlock(TimestepBlock):
dims, channels, self.out_channels, 3, padding=1 dims, channels, self.out_channels, 3, padding=1
) )
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1
)
def forward(self, x, emb): def forward(self, x, emb):
""" """
@ -251,7 +275,6 @@ class ResBlock(TimestepBlock):
self._forward, (x, emb), self.parameters(), self.use_checkpoint self._forward, (x, emb), self.parameters(), self.use_checkpoint
) )
def _forward(self, x, emb): def _forward(self, x, emb):
if self.updown: if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@ -297,7 +320,7 @@ class AttentionBlock(nn.Module):
else: else:
assert ( assert (
channels % num_head_channels == 0 channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.norm = normalization(channels) self.norm = normalization(channels)
@ -312,8 +335,10 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x): def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! return checkpoint(
#return pt_checkpoint(self._forward, x) # pytorch self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x): def _forward(self, x):
b, c, *spatial = x.shape b, c, *spatial = x.shape
@ -340,7 +365,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops. # We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes # The first computes the weight matrix, the second computes
# the combination of the value vectors. # the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops]) model.total_ops += th.DoubleTensor([matmul_ops])
@ -362,13 +387,15 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0 assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads) ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1
)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum( weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale 'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards ) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v) a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
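Note the scaling trick in the einsum above: rather than dividing the attention logits by sqrt(d) afterwards, both q and k are pre-multiplied by `scale = 1 / math.sqrt(math.sqrt(ch))`, which keeps the intermediate products small enough to be stable in float16 while preserving the usual factor:

    (s\,q)(s\,k)^\top = \frac{q\,k^\top}{\sqrt{ch}}, \qquad s = ch^{-1/4}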
@staticmethod @staticmethod
@ -397,12 +424,14 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1) q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum( weight = th.einsum(
"bct,bcs->bts", 'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length), (q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards ) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) a = th.einsum(
'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@staticmethod @staticmethod
@ -461,19 +490,24 @@ class UNetModel(nn.Module):
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
use_new_attention_order=False, use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True, legacy=True,
): ):
super().__init__() super().__init__()
if use_spatial_transformer: if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' assert (
context_dim is not None
), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None: if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' assert (
use_spatial_transformer
), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig: if type(context_dim) == ListConfig:
context_dim = list(context_dim) context_dim = list(context_dim)
@ -481,10 +515,14 @@ class UNetModel(nn.Module):
num_heads_upsample = num_heads num_heads_upsample = num_heads
if num_heads == -1: if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' assert (
num_head_channels != -1
), 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1: if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' assert (
num_heads != -1
), 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size self.image_size = image_size
self.in_channels = in_channels self.in_channels = in_channels
@ -545,8 +583,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append( layers.append(
AttentionBlock( AttentionBlock(
ch, ch,
@ -554,8 +596,14 @@ class UNetModel(nn.Module):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
) )
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -592,8 +640,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
self.middle_block = TimestepEmbedSequential( self.middle_block = TimestepEmbedSequential(
ResBlock( ResBlock(
ch, ch,
@ -609,9 +661,15 @@ class UNetModel(nn.Module):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim if not use_spatial_transformer
), else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
ResBlock( ResBlock(
ch, ch,
time_embed_dim, time_embed_dim,
@ -646,8 +704,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append( layers.append(
AttentionBlock( AttentionBlock(
ch, ch,
@ -655,8 +717,14 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample, num_heads=num_heads_upsample,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
) )
) )
if level and i == num_res_blocks: if level and i == num_res_blocks:
@ -673,7 +741,9 @@ class UNetModel(nn.Module):
up=True, up=True,
) )
if resblock_updown if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -682,14 +752,16 @@ class UNetModel(nn.Module):
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)
),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
normalization(ch), normalization(ch),
conv_nd(dims, model_channels, n_embed, 1), conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
) )
def convert_to_fp16(self): def convert_to_fp16(self):
""" """
@ -707,7 +779,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs): def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs. :param x: an [N x C x ...] Tensor of inputs.
@ -718,9 +790,11 @@ class UNetModel(nn.Module):
""" """
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), 'must specify y if and only if the model is class-conditional'
hs = [] hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:
@ -768,9 +842,9 @@ class EncoderUNetModel(nn.Module):
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
use_new_attention_order=False, use_new_attention_order=False,
pool="adaptive", pool='adaptive',
*args, *args,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
@ -888,7 +962,7 @@ class EncoderUNetModel(nn.Module):
) )
self._feature_size += ch self._feature_size += ch
self.pool = pool self.pool = pool
if pool == "adaptive": if pool == 'adaptive':
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
@ -896,7 +970,7 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)), zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(), nn.Flatten(),
) )
elif pool == "attention": elif pool == 'attention':
assert num_head_channels != -1 assert num_head_channels != -1
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
@ -905,13 +979,13 @@ class EncoderUNetModel(nn.Module):
(image_size // ds), ch, num_head_channels, out_channels (image_size // ds), ch, num_head_channels, out_channels
), ),
) )
elif pool == "spatial": elif pool == 'spatial':
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048), nn.Linear(self._feature_size, 2048),
nn.ReLU(), nn.ReLU(),
nn.Linear(2048, self.out_channels), nn.Linear(2048, self.out_channels),
) )
elif pool == "spatial_v2": elif pool == 'spatial_v2':
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048), nn.Linear(self._feature_size, 2048),
normalization(2048), normalization(2048),
@ -919,7 +993,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels), nn.Linear(2048, self.out_channels),
) )
else: else:
raise NotImplementedError(f"Unexpected {pool} pooling") raise NotImplementedError(f'Unexpected {pool} pooling')
def convert_to_fp16(self): def convert_to_fp16(self):
""" """
@ -942,20 +1016,21 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps. :param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs. :return: an [N x K] Tensor of outputs.
""" """
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
results = [] results = []
h = x.type(self.dtype) h = x.type(self.dtype)
for module in self.input_blocks: for module in self.input_blocks:
h = module(h, emb) h = module(h, emb)
if self.pool.startswith("spatial"): if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3))) results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb) h = self.middle_block(h, emb)
if self.pool.startswith("spatial"): if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3))) results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1) h = th.cat(results, axis=-1)
return self.out(h) return self.out(h)
else: else:
h = h.type(x.dtype) h = h.type(x.dtype)
return self.out(h) return self.out(h)

View File

@ -18,15 +18,24 @@ from einops import repeat
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): def make_beta_schedule(
if schedule == "linear": schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == 'linear':
betas = ( betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64,
)
** 2
) )
elif schedule == "cosine": elif schedule == 'cosine':
timesteps = ( timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s
) )
alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2) alphas = torch.cos(alphas).pow(2)
@ -34,23 +43,41 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = 1 - alphas[1:] / alphas[:-1] betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999) betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear": elif schedule == 'sqrt_linear':
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) betas = torch.linspace(
elif schedule == "sqrt": linear_start, linear_end, n_timestep, dtype=torch.float64
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 )
elif schedule == 'sqrt':
betas = (
torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
** 0.5
)
else: else:
raise ValueError(f"schedule '{schedule}' unknown.") raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy() return betas.numpy()
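A minimal numeric sketch of the 'linear' branch above, which is linear in sqrt(beta) space rather than in beta itself (values are illustrative; the defaults mirror the signature, and the cumulative product is the standard DDPM quantity consumed by the samplers above):

    import numpy as np
    import torch

    n_timestep, linear_start, linear_end = 1000, 1e-4, 2e-2
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
        )
        ** 2
    ).numpy()
    alphas_cumprod = np.cumprod(1.0 - betas)   # cumulative \bar\alpha_t
    print(betas[0], betas[-1], alphas_cumprod[-1])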
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == 'uniform': if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad': elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) ddim_timesteps = (
(
np.linspace(
0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps
)
)
** 2
).astype(int)
else: else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps # assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling) # add one to get the final alpha values right (the ones from first scale to data during sampling)
@ -60,17 +87,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
return steps_out return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): def make_ddim_sampling_parameters(
alphacums, ddim_timesteps, eta, verbose=True
):
# select alphas for computing the variance schedule # select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps] alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) alphas_prev = np.asarray(
[alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
)
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose: if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') print(
print(f'For the chosen value of eta, which is {eta}, ' f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
f'this results in the following sigma_t schedule for ddim sampler {sigmas}') )
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev return sigmas, alphas, alphas_prev
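The sigma computed above is the variance schedule from the DDIM paper (arXiv:2010.02502), with the cumulative alphas selected at the chosen DDIM timesteps:

    \sigma_t = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,
               \sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}

Setting eta = 0 gives the deterministic DDIM sampler; eta = 1 recovers ancestral DDPM-style noise.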
@ -109,7 +146,9 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments. explicitly take as arguments.
:param flag: if False, disable gradient checkpointing. :param flag: if False, disable gradient checkpointing.
""" """
if False: # disabled checkpointing to allow requires_grad = False for main model if (
False
): # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params) args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args) return CheckpointFunction.apply(func, len(inputs), *args)
else: else:
@ -129,7 +168,9 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *output_grads): def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad(): with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the # Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d # Tensor storage in place, which is not allowed for detach()'d
@ -160,12 +201,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only: if not repeat_only:
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half -math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device) ).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else: else:
embedding = repeat(timesteps, 'b -> b d', d=dim) embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding return embedding
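A self-contained sketch of what `timestep_embedding` computes for `repeat_only=False` (same math, minimal form):

    import math
    import torch

    def sinusoidal_embedding_sketch(timesteps, dim, max_period=10000):
        # timesteps: 1-D tensor; returns (N, dim) sinusoidal embeddings
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(half, dtype=torch.float32)
            / half
        )
        args = timesteps[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:  # odd dims get a zero-padded last channel
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    # e.g. sinusoidal_embedding_sketch(torch.tensor([0, 10, 999]), 320).shape == (3, 320)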
@ -215,6 +260,7 @@ class GroupNorm32(nn.GroupNorm):
def forward(self, x): def forward(self, x):
return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs): def conv_nd(dims, *args, **kwargs):
""" """
Create a 1D, 2D, or 3D convolution module. Create a 1D, 2D, or 3D convolution module.
@ -225,7 +271,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs) return nn.Conv2d(*args, **kwargs)
elif dims == 3: elif dims == 3:
return nn.Conv3d(*args, **kwargs) return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs): def linear(*args, **kwargs):
@ -245,15 +291,16 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs) return nn.AvgPool2d(*args, **kwargs)
elif dims == 3: elif dims == 3:
return nn.AvgPool3d(*args, **kwargs) return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module): class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config): def __init__(self, c_concat_config, c_crossattn_config):
super().__init__() super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config) self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config
)
def forward(self, c_concat, c_crossattn): def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat) c_concat = self.concat_conditioner(c_concat)
@ -262,6 +309,8 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False): def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device) noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise() return repeat_noise() if repeat else noise()

View File

@ -30,33 +30,45 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar) self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar) self.var = torch.exp(self.logvar)
if self.deterministic: if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self): def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x return x
def kl(self, other=None): def kl(self, other=None):
if self.deterministic: if self.deterministic:
return torch.Tensor([0.]) return torch.Tensor([0.0])
else: else:
if other is None: if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2) return 0.5 * torch.sum(
+ self.var - 1.0 - self.logvar, torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3]) dim=[1, 2, 3],
)
else: else:
return 0.5 * torch.sum( return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar, + self.var / other.var
dim=[1, 2, 3]) - 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1,2,3]): def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic: if self.deterministic:
return torch.Tensor([0.]) return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi) logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum( return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, logtwopi
dim=dims) + self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self): def mode(self):
return self.mean return self.mean
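The `kl` method above with `other=None` is the closed-form KL divergence between the predicted diagonal Gaussian and a standard normal, summed over the non-batch dimensions:

    D_{KL}\!\left(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\right)
      = \tfrac{1}{2}\sum\left(\mu^2 + \sigma^2 - 1 - \log\sigma^2\right)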
@ -74,7 +86,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
tensor = obj tensor = obj
break break
assert tensor is not None, "at least one argument must be a Tensor" assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to # Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp(). # Tensors, but it does not work for torch.exp().

View File

@ -10,24 +10,30 @@ class LitEma(nn.Module):
self.m_name2s_name = {} self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates self.register_buffer(
else torch.tensor(-1,dtype=torch.int)) 'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if p.requires_grad: if p.requires_grad:
#remove as '.'-character is not allowed in buffers # remove as '.'-character is not allowed in buffers
s_name = name.replace('.','') s_name = name.replace('.', '')
self.m_name2s_name.update({name:s_name}) self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name,p.clone().detach().data) self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = [] self.collected_params = []
def forward(self,model): def forward(self, model):
decay = self.decay decay = self.decay
if self.num_updates >= 0: if self.num_updates >= 0:
self.num_updates += 1 self.num_updates += 1
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) decay = min(
self.decay, (1 + self.num_updates) / (10 + self.num_updates)
)
one_minus_decay = 1.0 - decay one_minus_decay = 1.0 - decay
@ -38,8 +44,12 @@ class LitEma(nn.Module):
for key in m_param: for key in m_param:
if m_param[key].requires_grad: if m_param[key].requires_grad:
sname = self.m_name2s_name[key] sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname] = shadow_params[sname].type_as(
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) m_param[key]
)
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else: else:
assert not key in self.m_name2s_name assert not key in self.m_name2s_name
@ -48,7 +58,9 @@ class LitEma(nn.Module):
shadow_params = dict(self.named_buffers()) shadow_params = dict(self.named_buffers())
for key in m_param: for key in m_param:
if m_param[key].requires_grad: if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data
)
else: else:
assert not key in self.m_name2s_name assert not key in self.m_name2s_name
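The update in LitEma.forward() reduces to shadow -= (1 - decay) * (shadow - param), with the effective decay warmed up by the number of updates. A small sketch of that schedule on plain tensors (the names here are illustrative, not part of the module):

import torch

decay = 0.9999
param = torch.randn(10)    # stands in for one model parameter
shadow = param.clone()     # the EMA copy LitEma keeps as a buffer

for step in range(1, 6):
    # warm-up used above: the effective decay grows with the update count
    eff_decay = min(decay, (1 + step) / (10 + step))
    param = param + 0.01 * torch.randn(10)   # pretend an optimizer step happened
    shadow = shadow - (1.0 - eff_decay) * (shadow - param)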
View File

@ -8,18 +8,29 @@ from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from functools import partial from functools import partial
DEFAULT_PLACEHOLDER_TOKEN = ["*"] DEFAULT_PLACEHOLDER_TOKEN = ['*']
PROGRESSIVE_SCALE = 2000 PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string): def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, batch_encoding = tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") string,
tokens = batch_encoding["input_ids"] truncation=True,
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" max_length=77,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids']
assert (
torch.count_nonzero(tokens - 49407) == 2
), f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1] return tokens[0, 1]
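The assert above works because, for a usable placeholder, the padded CLIP encoding contains exactly two ids that differ from the end-of-text/pad id 49407: the start token and the single word token. A quick way to check a candidate string, assuming the same Hugging Face tokenizer used elsewhere in this diff:

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
enc = tokenizer(
    '*',
    truncation=True,
    max_length=77,
    return_length=True,
    return_overflowing_tokens=False,
    padding='max_length',
    return_tensors='pt',
)
tokens = enc['input_ids']
# 2 means the string maps to a single token and is safe to use as a placeholder
print(torch.count_nonzero(tokens - 49407))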
def get_bert_token_for_string(tokenizer, string): def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string) token = tokenizer(string)
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
@ -28,42 +39,54 @@ def get_bert_token_for_string(tokenizer, string):
return token return token
def get_embedding_for_clip_token(embedder, token): def get_embedding_for_clip_token(embedder, token):
return embedder(token.unsqueeze(0))[0, 0] return embedder(token.unsqueeze(0))[0, 0]
class EmbeddingManager(nn.Module): class EmbeddingManager(nn.Module):
def __init__( def __init__(
self, self,
embedder, embedder,
placeholder_strings=None, placeholder_strings=None,
initializer_words=None, initializer_words=None,
per_image_tokens=False, per_image_tokens=False,
num_vectors_per_token=1, num_vectors_per_token=1,
progressive_words=False, progressive_words=False,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.string_to_token_dict = {} self.string_to_token_dict = {}
self.string_to_param_dict = nn.ParameterDict() self.string_to_param_dict = nn.ParameterDict()
self.initial_embeddings = nn.ParameterDict() # These should not be optimized self.initial_embeddings = (
nn.ParameterDict()
) # These should not be optimized
self.progressive_words = progressive_words self.progressive_words = progressive_words
self.progressive_counter = 0 self.progressive_counter = 0
self.max_vectors_per_token = num_vectors_per_token self.max_vectors_per_token = num_vectors_per_token
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder if hasattr(
embedder, 'tokenizer'
): # using Stable Diffusion's CLIP encoder
self.is_clip = True self.is_clip = True
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) get_token_for_string = partial(
get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) get_clip_token_for_string, embedder.tokenizer
)
get_embedding_for_tkn = partial(
get_embedding_for_clip_token,
embedder.transformer.text_model.embeddings,
)
token_dim = 1280 token_dim = 1280
else: # using LDM's BERT encoder else: # using LDM's BERT encoder
self.is_clip = False self.is_clip = False
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) get_token_for_string = partial(
get_bert_token_for_string, embedder.tknz_fn
)
get_embedding_for_tkn = embedder.transformer.token_emb get_embedding_for_tkn = embedder.transformer.token_emb
token_dim = 1280 token_dim = 1280
@ -71,79 +94,140 @@ class EmbeddingManager(nn.Module):
placeholder_strings.extend(per_img_token_list) placeholder_strings.extend(per_img_token_list)
for idx, placeholder_string in enumerate(placeholder_strings): for idx, placeholder_string in enumerate(placeholder_strings):
token = get_token_for_string(placeholder_string) token = get_token_for_string(placeholder_string)
if initializer_words and idx < len(initializer_words): if initializer_words and idx < len(initializer_words):
init_word_token = get_token_for_string(initializer_words[idx]) init_word_token = get_token_for_string(initializer_words[idx])
with torch.no_grad(): with torch.no_grad():
init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) init_word_embedding = get_embedding_for_tkn(
init_word_token.cpu()
)
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) token_params = torch.nn.Parameter(
self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=True,
)
self.initial_embeddings[
placeholder_string
] = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=False,
)
else: else:
token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) token_params = torch.nn.Parameter(
torch.rand(
size=(num_vectors_per_token, token_dim),
requires_grad=True,
)
)
self.string_to_token_dict[placeholder_string] = token self.string_to_token_dict[placeholder_string] = token
self.string_to_param_dict[placeholder_string] = token_params self.string_to_param_dict[placeholder_string] = token_params
def forward( def forward(
self, self,
tokenized_text, tokenized_text,
embedded_text, embedded_text,
): ):
b, n, device = *tokenized_text.shape, tokenized_text.device b, n, device = *tokenized_text.shape, tokenized_text.device
for placeholder_string, placeholder_token in self.string_to_token_dict.items(): for (
placeholder_string,
placeholder_token,
) in self.string_to_token_dict.items():
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) placeholder_embedding = self.string_to_param_dict[
placeholder_string
].to(device)
if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement if (
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) self.max_vectors_per_token == 1
): # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(
tokenized_text == placeholder_token.to(device)
)
embedded_text[placeholder_idx] = placeholder_embedding embedded_text[placeholder_idx] = placeholder_embedding
else: # otherwise, need to insert and keep track of changing indices else: # otherwise, need to insert and keep track of changing indices
if self.progressive_words: if self.progressive_words:
self.progressive_counter += 1 self.progressive_counter += 1
max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE max_step_tokens = (
1 + self.progressive_counter // PROGRESSIVE_SCALE
)
else: else:
max_step_tokens = self.max_vectors_per_token max_step_tokens = self.max_vectors_per_token
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) num_vectors_for_token = min(
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if placeholder_rows.nelement() == 0: if placeholder_rows.nelement() == 0:
continue continue
sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) sorted_cols, sort_idx = torch.sort(
placeholder_cols, descending=True
)
sorted_rows = placeholder_rows[sort_idx] sorted_rows = placeholder_rows[sort_idx]
for idx in range(len(sorted_rows)): for idx in range(len(sorted_rows)):
row = sorted_rows[idx] row = sorted_rows[idx]
col = sorted_cols[idx] col = sorted_cols[idx]
new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] new_token_row = torch.cat(
new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] [
tokenized_text[row][:col],
placeholder_token.repeat(num_vectors_for_token).to(
device
),
tokenized_text[row][col + 1 :],
],
axis=0,
)[:n]
new_embed_row = torch.cat(
[
embedded_text[row][:col],
placeholder_embedding[:num_vectors_for_token],
embedded_text[row][col + 1 :],
],
axis=0,
)[:n]
embedded_text[row] = new_embed_row embedded_text[row] = new_embed_row
tokenized_text[row] = new_token_row tokenized_text[row] = new_token_row
return embedded_text return embedded_text
def save(self, ckpt_path): def save(self, ckpt_path):
torch.save({"string_to_token": self.string_to_token_dict, torch.save(
"string_to_param": self.string_to_param_dict}, ckpt_path) {
'string_to_token': self.string_to_token_dict,
'string_to_param': self.string_to_param_dict,
},
ckpt_path,
)
def load(self, ckpt_path): def load(self, ckpt_path):
ckpt = torch.load(ckpt_path, map_location='cpu') ckpt = torch.load(ckpt_path, map_location='cpu')
self.string_to_token_dict = ckpt["string_to_token"] self.string_to_token_dict = ckpt['string_to_token']
self.string_to_param_dict = ckpt["string_to_param"] self.string_to_param_dict = ckpt['string_to_param']
def get_embedding_norms_squared(self): def get_embedding_norms_squared(self):
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim all_params = torch.cat(
param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders list(self.string_to_param_dict.values()), axis=0
) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(
axis=-1
) # num_placeholders
return param_norm_squared return param_norm_squared
@ -151,14 +235,19 @@ class EmbeddingManager(nn.Module):
return self.string_to_param_dict.parameters() return self.string_to_param_dict.parameters()
def embedding_to_coarse_loss(self): def embedding_to_coarse_loss(self):
loss = 0. loss = 0.0
num_embeddings = len(self.initial_embeddings) num_embeddings = len(self.initial_embeddings)
for key in self.initial_embeddings: for key in self.initial_embeddings:
optimized = self.string_to_param_dict[key] optimized = self.string_to_param_dict[key]
coarse = self.initial_embeddings[key].clone().to(optimized.device) coarse = self.initial_embeddings[key].clone().to(optimized.device)
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings loss = (
loss
+ (optimized - coarse)
@ (optimized - coarse).T
/ num_embeddings
)
return loss return loss
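In the single-vector case, EmbeddingManager.forward() simply overwrites the text embedding at every position where the placeholder token id occurs. A rough standalone sketch of that replacement, with made-up shapes and a hypothetical placeholder id:

import torch

placeholder_id = 42                                # hypothetical token id for '*'
token_dim = 8

tokenized_text = torch.randint(0, 40, (2, 6))      # (batch, seq_len)
tokenized_text[0, 3] = placeholder_id
embedded_text = torch.zeros(2, 6, token_dim)

placeholder_embedding = torch.randn(1, token_dim)  # the learned vector

idx = torch.where(tokenized_text == placeholder_id)
embedded_text[idx] = placeholder_embedding         # broadcast into every match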
View File

@ -6,29 +6,39 @@ from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
import kornia import kornia
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,
) # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
def _expand_mask(mask, dtype, tgt_len = None):
def _expand_mask(mask, dtype, tgt_len=None):
""" """
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
""" """
bsz, src_len = mask.size() bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _build_causal_attention_mask(bsz, seq_len, dtype): def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens # lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf # pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min)) mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask mask = mask.unsqueeze(1) # expand mask
return mask return mask
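_build_causal_attention_mask() produces an additive mask: zero on and below the diagonal, a very large negative number above it, so softmax attention cannot look ahead. A tiny self-contained version for a short sequence:

import torch

def build_causal_mask(bsz, seq_len, dtype=torch.float32):
    # start from the most negative representable value everywhere,
    # then keep it only strictly above the diagonal
    mask = torch.full((bsz, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    mask.triu_(1)
    return mask.unsqueeze(1)   # (bsz, 1, seq_len, seq_len)

print(build_causal_mask(1, 3)[0, 0])   # zeros on/below the diagonal, ~-3.4e38 above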
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
def __init__(self): def __init__(self):
@ -38,7 +48,6 @@ class AbstractEncoder(nn.Module):
raise NotImplementedError raise NotImplementedError
class ClassEmbedder(nn.Module): class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'): def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__() super().__init__()
@ -56,11 +65,17 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder): class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers""" """Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(
self, n_embed, n_layer, vocab_size, max_seq_len=77, device='cuda'
):
super().__init__() super().__init__()
self.device = device self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, self.transformer = TransformerWrapper(
attn_layers=Encoder(dim=n_embed, depth=n_layer)) num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
)
def forward(self, tokens): def forward(self, tokens):
tokens = tokens.to(self.device) # meh tokens = tokens.to(self.device) # meh
@ -72,27 +87,42 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder): class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
def __init__(self, device='cuda', vq_interface=True, max_length=77):
super().__init__() super().__init__()
from transformers import BertTokenizerFast # TODO: add to requirements from transformers import (
BertTokenizerFast,
) # TODO: add to requirements
# Modified to allow to run on non-internet connected compute nodes. # Modified to allow to run on non-internet connected compute nodes.
# Model needs to be loaded into cache from an internet-connected machine # Model needs to be loaded into cache from an internet-connected machine
# by running: # by running:
# from transformers import BertTokenizerFast # from transformers import BertTokenizerFast
# BertTokenizerFast.from_pretrained("bert-base-uncased") # BertTokenizerFast.from_pretrained("bert-base-uncased")
try: try:
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True) self.tokenizer = BertTokenizerFast.from_pretrained(
'bert-base-uncased', local_files_only=True
)
except OSError: except OSError:
raise SystemExit("* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine.") raise SystemExit(
"* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine."
)
self.device = device self.device = device
self.vq_interface = vq_interface self.vq_interface = vq_interface
self.max_length = max_length self.max_length = max_length
def forward(self, text): def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
tokens = batch_encoding["input_ids"].to(self.device) truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
return tokens return tokens
@torch.no_grad() @torch.no_grad()
@ -108,53 +138,84 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder): class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers""" """Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0): def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device='cuda',
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__() super().__init__()
self.use_tknz_fn = use_tokenizer self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn: if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len
)
self.device = device self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, self.transformer = TransformerWrapper(
attn_layers=Encoder(dim=n_embed, depth=n_layer), num_tokens=vocab_size,
emb_dropout=embedding_dropout) max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout,
)
def forward(self, text, embedding_manager=None): def forward(self, text, embedding_manager=None):
if self.use_tknz_fn: if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device) tokens = self.tknz_fn(text) # .to(self.device)
else: else:
tokens = text tokens = text
z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager) z = self.transformer(
tokens, return_embeddings=True, embedding_manager=embedding_manager
)
return z return z
def encode(self, text, **kwargs): def encode(self, text, **kwargs):
# output of length 77 # output of length 77
return self(text, **kwargs) return self(text, **kwargs)
class SpatialRescaler(nn.Module): class SpatialRescaler(nn.Module):
def __init__(self, def __init__(
n_stages=1, self,
method='bilinear', n_stages=1,
multiplier=0.5, method='bilinear',
in_channels=3, multiplier=0.5,
out_channels=None, in_channels=3,
bias=False): out_channels=None,
bias=False,
):
super().__init__() super().__init__()
self.n_stages = n_stages self.n_stages = n_stages
assert self.n_stages >= 0 assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] assert method in [
'nearest',
'linear',
'bilinear',
'trilinear',
'bicubic',
'area',
]
self.multiplier = multiplier self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method) self.interpolator = partial(
torch.nn.functional.interpolate, mode=method
)
self.remap_output = out_channels is not None self.remap_output = out_channels is not None
if self.remap_output: if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') print(
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias
)
def forward(self,x): def forward(self, x):
for stage in range(self.n_stages): for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier) x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output: if self.remap_output:
x = self.channel_mapper(x) x = self.channel_mapper(x)
return x return x
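SpatialRescaler just applies torch.nn.functional.interpolate n_stages times, so the overall scale factor is multiplier ** n_stages. A short usage sketch; the import path is an assumption, since the file name is not shown in this diff:

import torch
from ldm.modules.encoders.modules import SpatialRescaler   # assumed location

rescaler = SpatialRescaler(n_stages=2, method='bilinear', multiplier=0.5)
x = torch.randn(1, 3, 64, 64)
print(rescaler(x).shape)   # torch.Size([1, 3, 16, 16]), i.e. 64 * 0.5 ** 2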
@ -162,57 +223,83 @@ class SpatialRescaler(nn.Module):
def encode(self, x): def encode(self, x):
return self(x) return self(x)
class FrozenCLIPEmbedder(AbstractEncoder): class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)""" """Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
def __init__(
self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77,
):
super().__init__() super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version,local_files_only=True) self.tokenizer = CLIPTokenizer.from_pretrained(
self.transformer = CLIPTextModel.from_pretrained(version,local_files_only=True) version, local_files_only=True
)
self.transformer = CLIPTextModel.from_pretrained(
version, local_files_only=True
)
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length
self.freeze() self.freeze()
def embedding_forward( def embedding_forward(
self, self,
input_ids = None, input_ids=None,
position_ids = None, position_ids=None,
inputs_embeds = None, inputs_embeds=None,
embedding_manager = None, embedding_manager=None,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] seq_length = (
input_ids.shape[-1]
if input_ids is not None
else inputs_embeds.shape[-2]
)
if position_ids is None: if position_ids is None:
position_ids = self.position_ids[:, :seq_length] position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids) inputs_embeds = self.token_embedding(input_ids)
if embedding_manager is not None: if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds) inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
position_embeddings = self.position_embedding(position_ids) return embeddings
embeddings = inputs_embeds + position_embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings) self.transformer.text_model.embeddings.forward = (
embedding_forward.__get__(self.transformer.text_model.embeddings)
)
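The binding above uses function.__get__(instance) to attach a plain function as a bound method of one existing object; the same trick is repeated below for the encoder, the text model, and the transformer. A toy illustration, unrelated to CLIP:

class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    return self.name.upper()

g = Greeter('clip')
g.greet = shout.__get__(g)   # bind the free function to this one instance
print(g.greet())             # CLIP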
def encoder_forward( def encoder_forward(
self, self,
inputs_embeds, inputs_embeds,
attention_mask = None, attention_mask=None,
causal_attention_mask = None, causal_attention_mask=None,
output_attentions = None, output_attentions=None,
output_hidden_states = None, output_hidden_states=None,
return_dict = None, return_dict=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_hidden_states = ( output_attentions
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
@ -239,44 +326,61 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return hidden_states return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) self.transformer.text_model.encoder.forward = encoder_forward.__get__(
self.transformer.text_model.encoder
)
def text_encoder_forward( def text_encoder_forward(
self, self,
input_ids = None, input_ids=None,
attention_mask = None, attention_mask=None,
position_ids = None, position_ids=None,
output_attentions = None, output_attentions=None,
output_hidden_states = None, output_hidden_states=None,
return_dict = None, return_dict=None,
embedding_manager = None, embedding_manager=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_hidden_states = ( output_attentions
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None: if input_ids is None:
raise ValueError("You have to specify either input_ids") raise ValueError('You have to specify either input_ids')
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager) hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding_manager=embedding_manager,
)
bsz, seq_len = input_shape bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here. # CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( causal_attention_mask = _build_causal_attention_mask(
hidden_states.device bsz, seq_len, hidden_states.dtype
) ).to(hidden_states.device)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype) attention_mask = _expand_mask(
attention_mask, hidden_states.dtype
)
last_hidden_state = self.encoder( last_hidden_state = self.encoder(
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
@ -291,17 +395,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return last_hidden_state return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) self.transformer.text_model.forward = text_encoder_forward.__get__(
self.transformer.text_model
)
def transformer_forward( def transformer_forward(
self, self,
input_ids = None, input_ids=None,
attention_mask = None, attention_mask=None,
position_ids = None, position_ids=None,
output_attentions = None, output_attentions=None,
output_hidden_states = None, output_hidden_states=None,
return_dict = None, return_dict=None,
embedding_manager = None, embedding_manager=None,
): ):
return self.text_model( return self.text_model(
input_ids=input_ids, input_ids=input_ids,
@ -310,11 +416,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
embedding_manager = embedding_manager embedding_manager=embedding_manager,
) )
self.transformer.forward = transformer_forward.__get__(self.transformer) self.transformer.forward = transformer_forward.__get__(
self.transformer
)
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
@ -322,9 +429,16 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False param.requires_grad = False
def forward(self, text, **kwargs): def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
tokens = batch_encoding["input_ids"].to(self.device) truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs) z = self.transformer(input_ids=tokens, **kwargs)
return z return z
@ -337,9 +451,17 @@ class FrozenCLIPTextEmbedder(nn.Module):
""" """
Uses the CLIP transformer encoder for text. Uses the CLIP transformer encoder for text.
""" """
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
def __init__(
self,
version='ViT-L/14',
device='cuda',
max_length=77,
n_repeat=1,
normalize=True,
):
super().__init__() super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu") self.model, _ = clip.load(version, jit=False, device='cpu')
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length
self.n_repeat = n_repeat self.n_repeat = n_repeat
@ -359,7 +481,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
def encode(self, text): def encode(self, text):
z = self(text) z = self(text)
if z.ndim==2: if z.ndim == 2:
z = z[:, None, :] z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z return z
@ -367,29 +489,42 @@ class FrozenCLIPTextEmbedder(nn.Module):
class FrozenClipImageEmbedder(nn.Module): class FrozenClipImageEmbedder(nn.Module):
""" """
Uses the CLIP image encoder. Uses the CLIP image encoder.
""" """
def __init__( def __init__(
self, self,
model, model,
jit=False, jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu', device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False, antialias=False,
): ):
super().__init__() super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit) self.model, _ = clip.load(name=model, device=device, jit=jit)
self.antialias = antialias self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) self.register_buffer(
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False,
)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False,
)
def preprocess(self, x): def preprocess(self, x):
# normalize to [0,1] # normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224), x = kornia.geometry.resize(
interpolation='bicubic',align_corners=True, x,
antialias=self.antialias) (224, 224),
x = (x + 1.) / 2. interpolation='bicubic',
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip # renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std) x = kornia.enhance.normalize(x, self.mean, self.std)
return x return x
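preprocess() expects images in [-1, 1], rescales them to [0, 1], and then applies CLIP's channel-wise normalization. A plain-PyTorch equivalent of those two steps (the bicubic resize handled by kornia is omitted here):

import torch

clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

x = torch.rand(1, 3, 224, 224) * 2.0 - 1.0   # image batch in [-1, 1]
x = (x + 1.0) / 2.0                          # back to [0, 1]
x = (x - clip_mean) / clip_std               # CLIP normalization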
@ -399,7 +534,8 @@ class FrozenClipImageEmbedder(nn.Module):
return self.model.encode_image(self.preprocess(x)) return self.model.encode_image(self.preprocess(x))
if __name__ == "__main__": if __name__ == '__main__':
from ldm.util import count_params from ldm.util import count_params
model = FrozenCLIPEmbedder() model = FrozenCLIPEmbedder()
count_params(model, verbose=True) count_params(model, verbose=True)
View File

@ -1,2 +1,6 @@
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr from ldm.modules.image_degradation.bsrgan import (
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
degradation_bsrgan_variant as degradation_fn_bsr_light,
)
View File

@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf): def modcrop_np(img, sf):
''' """
Args: Args:
img: numpy image, WxH or WxHxC img: numpy image, WxH or WxHxC
sf: scale factor sf: scale factor
Return: Return:
cropped image cropped image
''' """
w, h = img.shape[:2] w, h = img.shape[:2]
im = np.copy(img) im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...] return im[: w - w % sf, : h - h % sf, ...]
""" """
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one # Loop over the small kernel to fill the big one
for r in range(k_size): for r in range(k_size):
for c in range(k_size): for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR # Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2 crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop] cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel """generate an anisotropic Gaussian kernel
Args: Args:
ksize : e.g., 15, kernel size ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range theta : [0, pi], rotation angle range
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel k : kernel
""" """
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]]) V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]]) D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k): def blur(x, k):
''' """
x: image, NxcxHxW x: image, NxcxHxW
k: kernel, Nx1xhxw k: kernel, Nx1xhxw
''' """
n, c = x.shape[:2] n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1) k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3]) k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3]) x = x.view(n, c, x.shape[2], x.shape[3])
return x return x
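blur() replicates the border, then runs one grouped convolution so each image in the batch is filtered with its own kernel. A small usage sketch with a 3x3 box kernel; importing blur from ldm.modules.image_degradation.bsrgan is an assumption based on the __init__ shown earlier in this diff:

import torch
from ldm.modules.image_degradation.bsrgan import blur   # assumed location

x = torch.rand(2, 3, 32, 32)            # N x C x H x W
k = torch.full((2, 1, 3, 3), 1.0 / 9)   # one 3x3 box kernel per image
print(blur(x, k).shape)                 # torch.Size([2, 3, 32, 32])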
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): def gen_kernel(
"""" k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang # Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta # Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2]) LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], Q = np.array(
[np.sin(theta), np.cos(theta)]]) [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image) # Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None] MU = MU[None, None, :, None]
# Create meshgrid for Gaussian # Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize] hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) [x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std) arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg) h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0 h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs): def fspecial(filter_type, *args, **kwargs):
''' """
python code from: python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
''' """
if filter_type == 'gaussian': if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs) return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian': if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3): def bicubic_degradation(x, sf=3):
''' """
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
sf: down-scale factor sf: down-scale factor
Return: Return:
bicubicly downsampled LR image bicubicly downsampled LR image
''' """
x = util.imresize_np(x, scale=1 / sf) x = util.imresize_np(x, scale=1 / sf)
return x return x
def srmd_degradation(x, k, sf=3): def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling """blur + bicubic downsampling
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271}, pages={3262--3271},
year={2018} year={2018}
} }
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
return x return x
def dpsr_degradation(x, k, sf=3): def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur """bicubic downsampling + blur
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681}, pages={1671--1681},
year={2019} year={2019}
} }
''' """
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x return x
def classical_degradation(x, k, sf=3): def classical_degradation(x, k, sf=3):
''' blur + downsampling """blur + downsampling
Args: Args:
x: HxWxC image, [0, 1]/[0, 255] x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double k: hxw, double
sf: down-scale factor sf: down-scale factor
Return: Return:
downsampled LR image downsampled LR image
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0 st = 0
@ -328,10 +350,19 @@ def add_blur(img, sf=4):
if random.random() < 0.5: if random.random() < 0.5:
l1 = wd2 * random.random() l1 = wd2 * random.random()
l2 = wd2 * random.random() l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) k = anisotropic_Gaussian(
ksize=2 * random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else: else:
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) k = fspecial(
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img return img
@ -344,7 +375,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1) sf1 = random.uniform(0.5 / sf, 1)
else: else:
sf1 = 1.0 sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -366,19 +401,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0) # img = np.clip(img, 0.0, 1.0)
# return img # return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2) noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand() rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise else: # add noise
L = noise_level2 / 255. L = noise_level2 / 255.0
D = np.diag(np.random.rand(3)) D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3)) U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U) conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
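add_Gaussian_noise() randomly chooses between per-channel noise, grayscale noise shared across channels, and correlated color noise. A minimal standalone version of the first branch, assuming a float32 image in [0, 1]:

import random
import numpy as np

def add_color_gaussian_noise(img, noise_level1=2, noise_level2=25):
    # img: HxWxC float32 array in [0, 1]
    noise_level = random.randint(noise_level1, noise_level2)
    img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    return np.clip(img, 0.0, 1.0)

noisy = add_color_gaussian_noise(np.random.rand(16, 16, 3).astype(np.float32))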
@ -388,28 +430,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
rnum = random.random() rnum = random.random()
if rnum > 0.6: if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4: elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: else:
L = noise_level2 / 255. L = noise_level2 / 255.0
D = np.diag(np.random.rand(3)) D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3)) U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U) conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
def add_Poisson_noise(img): def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255. img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4] vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5: if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals img = np.random.poisson(img * vals).astype(np.float32) / vals
else: else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis] img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -418,7 +469,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img): def add_JPEG_noise(img):
quality_factor = random.randint(30, 95) quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1) img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img return img
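add_JPEG_noise() creates compression artifacts by encoding the image to JPEG in memory at a random quality factor and decoding it again. The same round trip on a uint8 BGR array, without the RGB/float conversions done by util:

import cv2
import numpy as np

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)   # BGR, uint8
quality_factor = 40
ok, enc = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
degraded = cv2.imdecode(enc, 1)   # same shape, now with JPEG artifacts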
@ -428,10 +481,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2] h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize) rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq return lq, hq
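random_crop() draws an lq_patchsize crop from the low-resolution image and the matching sf-times-larger crop from the aligned high-resolution image. A quick check of that shape relationship with placeholder arrays:

import random
import numpy as np

sf, lq_patchsize = 4, 16
lq = np.random.rand(32, 32, 3).astype(np.float32)
hq = np.random.rand(32 * sf, 32 * sf, 3).astype(np.float32)

rnd_h = random.randint(0, lq.shape[0] - lq_patchsize)
rnd_w = random.randint(0, lq.shape[1] - lq_patchsize)
lq_patch = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
hq_patch = hq[
    rnd_h * sf : (rnd_h + lq_patchsize) * sf,
    rnd_w * sf : (rnd_w + lq_patchsize) * sf,
    :,
]
print(lq_patch.shape, hq_patch.shape)   # (16, 16, 3) (64, 64, 3)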
@ -452,7 +509,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf sf_ori = sf
h1, w1 = img.shape[:2] h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2] h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf: if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -462,8 +519,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
img = util.imresize_np(img, 1 / 2, True) img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
@ -472,7 +532,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -487,19 +550,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.75: if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -544,15 +618,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf sf_ori = sf
h1, w1 = image.shape[:2] h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2] h, w = image.shape[:2]
hq = image.copy() hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
image = util.imresize_np(image, 1 / 2, True) image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
@ -561,7 +638,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -576,19 +656,33 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.75: if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -609,12 +703,19 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise # add final JPEG compression noise
image = add_JPEG_noise(image) image = add_JPEG_noise(image)
image = util.single2uint(image) image = util.single2uint(image)
example = {"image":image} example = {'image': image}
return example return example
# TODO in case there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... # TODO in case there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): def degradation_bsrgan_plus(
img,
sf=4,
shuffle_prob=0.5,
use_sharp=True,
lq_patchsize=64,
isp_model=None,
):
""" """
This is an extended degradation model by combining This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN the degradation models of BSRGAN and Real-ESRGAN
@ -630,7 +731,7 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
""" """
h1, w1 = img.shape[:2] h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2] h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf: if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -645,8 +746,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
else: else:
shuffle_order = list(range(13)) shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one # local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) shuffle_order[2:6] = random.sample(
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) shuffle_order[2:6], len(range(2, 6))
)
shuffle_order[9:13] = random.sample(
shuffle_order[9:13], len(range(9, 13))
)
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
@ -689,8 +794,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
print('check the shuffle!') print('check the shuffle!')
# resize to desired size # resize to desired size
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
# add final JPEG compression noise # add final JPEG compression noise
img = add_JPEG_noise(img) img = add_JPEG_noise(img)
@ -702,29 +810,37 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
if __name__ == '__main__': if __name__ == '__main__':
print("hey") print('hey')
img = util.imread_uint('utils/test.png', 3) img = util.imread_uint('utils/test.png', 3)
print(img) print(img)
img = util.uint2single(img) img = util.uint2single(img)
print(img) print(img)
img = img[:448, :448] img = img[:448, :448]
h = img.shape[0] // 4 h = img.shape[0] // 4
print("resizing to", h) print('resizing to', h)
sf = 4 sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf) deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20): for i in range(20):
print(i) print(i)
img_lq = deg_fn(img) img_lq = deg_fn(img)
print(img_lq) print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] img_lq_bicubic = albumentations.SmallestMaxSize(
print(img_lq.shape) max_size=h, interpolation=cv2.INTER_CUBIC
print("bicubic", img_lq_bicubic.shape) )(image=img)['image']
print(img_hq.shape) print(img_lq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), print('bicubic', img_lq_bicubic.shape)
interpolation=0) print(img_hq.shape)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), lq_nearest = cv2.resize(
interpolation=0) util.single2uint(img_lq),
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
util.imsave(img_concat, str(i) + '.png') interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png')
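Editor's note: for readers tracing the degradation pipeline rather than the reformatting, here is a minimal usage sketch of degradation_bsrgan_variant as defined in the file above. It is not part of the commit; the module path ldm.modules.image_degradation.bsrgan and the test-image path are assumptions, while the dict return value ({'image': uint8 array}) matches the function body shown earlier in this hunk.

from functools import partial

from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant   # assumed path
import ldm.modules.image_degradation.utils_image as util

# load an RGB test image and convert it to single precision in [0, 1]
img_hq = util.uint2single(util.imread_uint('utils/test.png', 3))

deg_fn = partial(degradation_bsrgan_variant, sf=4)
img_lq = deg_fn(img_hq)['image']          # uint8 low-quality image
util.imsave(img_lq, 'lq_sample.png')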
View File
@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf): def modcrop_np(img, sf):
''' """
Args: Args:
img: numpy image, WxH or WxHxC img: numpy image, WxH or WxHxC
sf: scale factor sf: scale factor
Return: Return:
cropped image cropped image
''' """
w, h = img.shape[:2] w, h = img.shape[:2]
im = np.copy(img) im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...] return im[: w - w % sf, : h - h % sf, ...]
""" """
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one # Loop over the small kernel to fill the big one
for r in range(k_size): for r in range(k_size):
for c in range(k_size): for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR # Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2 crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop] cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel """generate an anisotropic Gaussian kernel
Args: Args:
ksize : e.g., 15, kernel size ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range theta : [0, pi], rotation angle range
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel k : kernel
""" """
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]]) V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]]) D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k): def blur(x, k):
''' """
x: image, NxcxHxW x: image, NxcxHxW
k: kernel, Nx1xhxw k: kernel, Nx1xhxw
''' """
n, c = x.shape[:2] n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1) k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3]) k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3]) x = x.view(n, c, x.shape[2], x.shape[3])
return x return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): def gen_kernel(
"""" k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang # Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta # Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2]) LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], Q = np.array(
[np.sin(theta), np.cos(theta)]]) [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image) # Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None] MU = MU[None, None, :, None]
# Create meshgrid for Gaussian # Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize] hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) [x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std) arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg) h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0 h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs): def fspecial(filter_type, *args, **kwargs):
''' """
python code from: python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
''' """
if filter_type == 'gaussian': if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs) return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian': if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3): def bicubic_degradation(x, sf=3):
''' """
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
sf: down-scale factor sf: down-scale factor
Return: Return:
bicubicly downsampled LR image bicubicly downsampled LR image
''' """
x = util.imresize_np(x, scale=1 / sf) x = util.imresize_np(x, scale=1 / sf)
return x return x
def srmd_degradation(x, k, sf=3): def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling """blur + bicubic downsampling
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271}, pages={3262--3271},
year={2018} year={2018}
} }
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
return x return x
def dpsr_degradation(x, k, sf=3): def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur """bicubic downsampling + blur
Args: Args:
x: HxWxC image, [0, 1] x: HxWxC image, [0, 1]
k: hxw, double k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681}, pages={1671--1681},
year={2019} year={2019}
} }
''' """
x = bicubic_degradation(x, sf=sf) x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x return x
def classical_degradation(x, k, sf=3): def classical_degradation(x, k, sf=3):
''' blur + downsampling """blur + downsampling
Args: Args:
x: HxWxC image, [0, 1]/[0, 255] x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double k: hxw, double
sf: down-scale factor sf: down-scale factor
Return: Return:
downsampled LR image downsampled LR image
''' """
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0 st = 0
@ -326,16 +348,25 @@ def add_blur(img, sf=4):
wd2 = 4.0 + sf wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf wd = 2.0 + 0.2 * sf
wd2 = wd2/4 wd2 = wd2 / 4
wd = wd/4 wd = wd / 4
if random.random() < 0.5: if random.random() < 0.5:
l1 = wd2 * random.random() l1 = wd2 * random.random()
l2 = wd2 * random.random() l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) k = anisotropic_Gaussian(
ksize=random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else: else:
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) k = fspecial(
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 'gaussian', random.randint(2, 4) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img return img
@ -348,7 +379,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1) sf1 = random.uniform(0.5 / sf, 1)
else: else:
sf1 = 1.0 sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -370,19 +405,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0) # img = np.clip(img, 0.0, 1.0)
# return img # return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2) noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand() rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise else: # add noise
L = noise_level2 / 255. L = noise_level2 / 255.0
D = np.diag(np.random.rand(3)) D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3)) U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U) conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -392,28 +434,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
rnum = random.random() rnum = random.random()
if rnum > 0.6: if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4: elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: else:
L = noise_level2 / 255. L = noise_level2 / 255.0
D = np.diag(np.random.rand(3)) D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3)) U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U) conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
def add_Poisson_noise(img): def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255. img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4] vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5: if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals img = np.random.poisson(img * vals).astype(np.float32) / vals
else: else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis] img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
return img return img
@ -422,7 +473,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img): def add_JPEG_noise(img):
quality_factor = random.randint(80, 95) quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1) img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img return img
@ -432,10 +485,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2] h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize) rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq return lq, hq
@ -456,7 +513,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf sf_ori = sf
h1, w1 = img.shape[:2] h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2] h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf: if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -466,8 +523,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
img = util.imresize_np(img, 1 / 2, True) img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
@ -476,7 +536,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -491,19 +554,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.75: if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), img = cv2.resize(
interpolation=random.choice([1, 2, 3])) img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -548,15 +622,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf sf_ori = sf
h1, w1 = image.shape[:2] h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2] h, w = image.shape[:2]
hq = image.copy() hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else: else:
image = util.imresize_np(image, 1 / 2, True) image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
@ -565,7 +642,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7) shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order: for i in shuffle_order:
@ -583,20 +663,34 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2 # downsample2
if random.random() < 0.8: if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf) sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), image = cv2.resize(
interpolation=random.choice([1, 2, 3])) image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else: else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf) k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel k_shifted = (
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 3: elif i == 3:
# downsample3 # downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0) image = np.clip(image, 0.0, 1.0)
elif i == 4: elif i == 4:
@ -617,34 +711,41 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise # add final JPEG compression noise
image = add_JPEG_noise(image) image = add_JPEG_noise(image)
image = util.single2uint(image) image = util.single2uint(image)
example = {"image": image} example = {'image': image}
return example return example
if __name__ == '__main__': if __name__ == '__main__':
print("hey") print('hey')
img = util.imread_uint('utils/test.png', 3) img = util.imread_uint('utils/test.png', 3)
img = img[:448, :448] img = img[:448, :448]
h = img.shape[0] // 4 h = img.shape[0] // 4
print("resizing to", h) print('resizing to', h)
sf = 4 sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf) deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20): for i in range(20):
print(i) print(i)
img_hq = img img_hq = img
img_lq = deg_fn(img)["image"] img_lq = deg_fn(img)['image']
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq) print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img_hq)['image']
print(img_lq.shape) print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape) print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape) print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), lq_nearest = cv2.resize(
interpolation=0) util.single2uint(img_lq),
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0,
interpolation=0) )
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png') util.imsave(img_concat, str(i) + '.png')
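Editor's note: this second copy of the pipeline appears to be the lighter variant (note the wd/4 and wd2/4 scaling in add_blur above, which weakens the blur kernels). A hedged sketch of wrapping it as an HQ-to-(LQ, HQ) pair transform follows; the module path bsrgan_light and the class name are illustrative, not part of the commit.

import numpy as np

from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant   # assumed path
import ldm.modules.image_degradation.utils_image as util


class LightDegradePair:
    """Map an HQ uint8 RGB image to an (LQ, HQ) pair of [0, 1] float arrays."""

    def __init__(self, sf=4):
        self.sf = sf

    def __call__(self, hq_uint8: np.ndarray):
        hq = util.uint2single(hq_uint8)                 # degradation expects [0, 1] floats
        lq_uint8 = degradation_bsrgan_variant(hq, sf=self.sf)['image']
        return util.uint2single(lq_uint8), hq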
View File
@ -6,13 +6,14 @@ import torch
import cv2 import cv2
from torchvision.utils import make_grid from torchvision.utils import make_grid
from datetime import datetime from datetime import datetime
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
''' """
# -------------------------------------------- # --------------------------------------------
# Kai Zhang (github: https://github.com/cszn) # Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019 # 03/Mar/2019
@ -20,10 +21,22 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# https://github.com/twhui/SRGAN-pyTorch # https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR # https://github.com/xinntao/BasicSR
# -------------------------------------------- # --------------------------------------------
''' """
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] IMG_EXTENSIONS = [
'.jpg',
'.JPG',
'.jpeg',
'.JPEG',
'.png',
'.PNG',
'.ppm',
'.PPM',
'.bmp',
'.BMP',
'.tif',
]
def is_image_file(filename): def is_image_file(filename):
@ -49,19 +62,19 @@ def surf(Z, cmap='rainbow', figsize=None):
ax3 = plt.axes(projection='3d') ax3 = plt.axes(projection='3d')
w, h = Z.shape[:2] w, h = Z.shape[:2]
xx = np.arange(0,w,1) xx = np.arange(0, w, 1)
yy = np.arange(0,h,1) yy = np.arange(0, h, 1)
X, Y = np.meshgrid(xx, yy) X, Y = np.meshgrid(xx, yy)
ax3.plot_surface(X,Y,Z,cmap=cmap) ax3.plot_surface(X, Y, Z, cmap=cmap)
#ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap) # ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
plt.show() plt.show()
''' """
# -------------------------------------------- # --------------------------------------------
# get image paths # get image paths
# -------------------------------------------- # --------------------------------------------
''' """
def get_image_paths(dataroot): def get_image_paths(dataroot):
@ -83,26 +96,26 @@ def _get_paths_from_images(path):
return images return images
''' """
# -------------------------------------------- # --------------------------------------------
# split large images into small images # split large images into small images
# -------------------------------------------- # --------------------------------------------
''' """
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2] w, h = img.shape[:2]
patches = [] patches = []
if w > p_max and h > p_max: if w > p_max and h > p_max:
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))
w1.append(w-p_size) w1.append(w - p_size)
h1.append(h-p_size) h1.append(h - p_size)
# print(w1) # print(w1)
# print(h1) # print(h1)
for i in w1: for i in w1:
for j in h1: for j in h1:
patches.append(img[i:i+p_size, j:j+p_size,:]) patches.append(img[i : i + p_size, j : j + p_size, :])
else: else:
patches.append(img) patches.append(img)
@ -118,11 +131,21 @@ def imssave(imgs, img_path):
for i, img in enumerate(imgs): for i, img in enumerate(imgs):
if img.ndim == 3: if img.ndim == 3:
img = img[:, :, [2, 1, 0]] img = img[:, :, [2, 1, 0]]
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') new_path = os.path.join(
os.path.dirname(img_path),
img_name + str('_s{:04d}'.format(i)) + '.png',
)
cv2.imwrite(new_path, img) cv2.imwrite(new_path, img)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): def split_imageset(
original_dataroot,
taget_dataroot,
n_channels=3,
p_size=800,
p_overlap=96,
p_max=1000,
):
""" """
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
@ -139,15 +162,18 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800,
# img_name, ext = os.path.splitext(os.path.basename(img_path)) # img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels) img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max) patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) imssave(
#if original_dataroot == taget_dataroot: patches, os.path.join(taget_dataroot, os.path.basename(img_path))
#del img_path )
# if original_dataroot == taget_dataroot:
# del img_path
'''
"""
# -------------------------------------------- # --------------------------------------------
# makedir # makedir
# -------------------------------------------- # --------------------------------------------
''' """
def mkdir(path): def mkdir(path):
@ -171,12 +197,12 @@ def mkdir_and_rename(path):
os.makedirs(path) os.makedirs(path)
''' """
# -------------------------------------------- # --------------------------------------------
# read image from path # read image from path
# opencv is fast, but read BGR numpy image # opencv is fast, but read BGR numpy image
# -------------------------------------------- # --------------------------------------------
''' """
# -------------------------------------------- # --------------------------------------------
@ -206,6 +232,7 @@ def imsave(img, img_path):
img = img[:, :, [2, 1, 0]] img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img) cv2.imwrite(img_path, img)
def imwrite(img, img_path): def imwrite(img, img_path):
img = np.squeeze(img) img = np.squeeze(img)
if img.ndim == 3: if img.ndim == 3:
@ -213,7 +240,6 @@ def imwrite(img, img_path):
cv2.imwrite(img_path, img) cv2.imwrite(img_path, img)
# -------------------------------------------- # --------------------------------------------
# get single image of size HxWxn_channels (BGR) # get single image of size HxWxn_channels (BGR)
# -------------------------------------------- # --------------------------------------------
@ -221,7 +247,7 @@ def read_img(path):
# read image by cv2 # read image by cv2
# return: Numpy float32, HWC, BGR, [0,1] # return: Numpy float32, HWC, BGR, [0,1]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
img = img.astype(np.float32) / 255. img = img.astype(np.float32) / 255.0
if img.ndim == 2: if img.ndim == 2:
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)
# some images have 4 channels # some images have 4 channels
@ -230,7 +256,7 @@ def read_img(path):
return img return img
''' """
# -------------------------------------------- # --------------------------------------------
# image format conversion # image format conversion
# -------------------------------------------- # --------------------------------------------
@ -238,7 +264,7 @@ def read_img(path):
# numpy(single) <---> tensor # numpy(single) <---> tensor
# numpy(unit) <---> tensor # numpy(unit) <---> tensor
# -------------------------------------------- # --------------------------------------------
''' """
# -------------------------------------------- # --------------------------------------------
@ -248,22 +274,22 @@ def read_img(path):
def uint2single(img): def uint2single(img):
return np.float32(img/255.) return np.float32(img / 255.0)
def single2uint(img): def single2uint(img):
return np.uint8((img.clip(0, 1)*255.).round()) return np.uint8((img.clip(0, 1) * 255.0).round())
def uint162single(img): def uint162single(img):
return np.float32(img/65535.) return np.float32(img / 65535.0)
def single2uint16(img): def single2uint16(img):
return np.uint16((img.clip(0, 1)*65535.).round()) return np.uint16((img.clip(0, 1) * 65535.0).round())
# -------------------------------------------- # --------------------------------------------
@ -275,14 +301,25 @@ def single2uint16(img):
def uint2tensor4(img): def uint2tensor4(img):
if img.ndim == 2: if img.ndim == 2:
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
.unsqueeze(0)
)
# convert uint to 3-dimensional torch tensor # convert uint to 3-dimensional torch tensor
def uint2tensor3(img): def uint2tensor3(img):
if img.ndim == 2: if img.ndim == 2:
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
)
# convert 2/3/4-dimensional torch tensor to uint # convert 2/3/4-dimensional torch tensor to uint
@ -290,7 +327,7 @@ def tensor2uint(img):
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
if img.ndim == 3: if img.ndim == 3:
img = np.transpose(img, (1, 2, 0)) img = np.transpose(img, (1, 2, 0))
return np.uint8((img*255.0).round()) return np.uint8((img * 255.0).round())
# -------------------------------------------- # --------------------------------------------
@ -305,7 +342,12 @@ def single2tensor3(img):
# convert single (HxWxC) to 4-dimensional torch tensor # convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img): def single2tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.unsqueeze(0)
)
# convert torch tensor to single # convert torch tensor to single
@ -316,6 +358,7 @@ def tensor2single(img):
return img return img
# convert torch tensor to single # convert torch tensor to single
def tensor2single3(img): def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy() img = img.data.squeeze().float().cpu().numpy()
@ -327,30 +370,48 @@ def tensor2single3(img):
def single2tensor5(img): def single2tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1, 3)
.float()
.unsqueeze(0)
)
def single32tensor5(img): def single32tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) return (
torch.from_numpy(np.ascontiguousarray(img))
.float()
.unsqueeze(0)
.unsqueeze(0)
)
def single42tensor4(img): def single42tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() return (
torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
)
# from skimage.io import imread, imsave # from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
''' """
Converts a torch Tensor into an image Numpy array of BGR channel order Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
''' """
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp tensor = (
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] tensor.squeeze().float().cpu().clamp_(*min_max)
) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (
min_max[1] - min_max[0]
) # to range [0,1]
n_dim = tensor.dim() n_dim = tensor.dim()
if n_dim == 4: if n_dim == 4:
n_img = len(tensor) n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() img_np = make_grid(
tensor, nrow=int(math.sqrt(n_img)), normalize=False
).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3: elif n_dim == 3:
img_np = tensor.numpy() img_np = tensor.numpy()
@ -359,14 +420,17 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
img_np = tensor.numpy() img_np = tensor.numpy()
else: else:
raise TypeError( raise TypeError(
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(
n_dim
)
)
if out_type == np.uint8: if out_type == np.uint8:
img_np = (img_np * 255.0).round() img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.uint8() WILL NOT round by default. # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
return img_np.astype(out_type) return img_np.astype(out_type)
''' """
# -------------------------------------------- # --------------------------------------------
# Augmentation, flip and/or rotate # Augmentation, flip and/or rotate
# -------------------------------------------- # --------------------------------------------
@ -374,12 +438,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
# (1) augment_img: numpy image of WxHxC or WxH # (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH # (2) augment_img_tensor4: tensor image 1xCxWxH
# -------------------------------------------- # --------------------------------------------
''' """
def augment_img(img, mode=0): def augment_img(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn) """Kai Zhang (github: https://github.com/cszn)"""
'''
if mode == 0: if mode == 0:
return img return img
elif mode == 1: elif mode == 1:
@ -399,8 +462,7 @@ def augment_img(img, mode=0):
def augment_img_tensor4(img, mode=0): def augment_img_tensor4(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn) """Kai Zhang (github: https://github.com/cszn)"""
'''
if mode == 0: if mode == 0:
return img return img
elif mode == 1: elif mode == 1:
@ -420,8 +482,7 @@ def augment_img_tensor4(img, mode=0):
def augment_img_tensor(img, mode=0): def augment_img_tensor(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn) """Kai Zhang (github: https://github.com/cszn)"""
'''
img_size = img.size() img_size = img.size()
img_np = img.data.cpu().numpy() img_np = img.data.cpu().numpy()
if len(img_size) == 3: if len(img_size) == 3:
@ -484,11 +545,11 @@ def augment_imgs(img_list, hflip=True, rot=True):
return [_augment(img) for img in img_list] return [_augment(img) for img in img_list]
''' """
# -------------------------------------------- # --------------------------------------------
# modcrop and shave # modcrop and shave
# -------------------------------------------- # --------------------------------------------
''' """
def modcrop(img_in, scale): def modcrop(img_in, scale):
@ -497,11 +558,11 @@ def modcrop(img_in, scale):
if img.ndim == 2: if img.ndim == 2:
H, W = img.shape H, W = img.shape
H_r, W_r = H % scale, W % scale H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r] img = img[: H - H_r, : W - W_r]
elif img.ndim == 3: elif img.ndim == 3:
H, W, C = img.shape H, W, C = img.shape
H_r, W_r = H % scale, W % scale H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r, :] img = img[: H - H_r, : W - W_r, :]
else: else:
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
return img return img
@ -511,11 +572,11 @@ def shave(img_in, border=0):
# img_in: Numpy, HWC or HW # img_in: Numpy, HWC or HW
img = np.copy(img_in) img = np.copy(img_in)
h, w = img.shape[:2] h, w = img.shape[:2]
img = img[border:h-border, border:w-border] img = img[border : h - border, border : w - border]
return img return img
''' """
# -------------------------------------------- # --------------------------------------------
# image processing process on numpy image # image processing process on numpy image
# channel_convert(in_c, tar_type, img_list): # channel_convert(in_c, tar_type, img_list):
@ -523,74 +584,92 @@ def shave(img_in, border=0):
# bgr2ycbcr(img, only_y=True): # bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img): # ycbcr2rgb(img):
# -------------------------------------------- # --------------------------------------------
''' """
def rgb2ycbcr(img, only_y=True): def rgb2ycbcr(img, only_y=True):
'''same as matlab rgb2ycbcr """same as matlab rgb2ycbcr
only_y: only return Y channel only_y: only return Y channel
Input: Input:
uint8, [0, 255] uint8, [0, 255]
float, [0, 1] float, [0, 1]
''' """
in_img_type = img.dtype in_img_type = img.dtype
img.astype(np.float32) img.astype(np.float32)
if in_img_type != np.uint8: if in_img_type != np.uint8:
img *= 255. img *= 255.0
# convert # convert
if only_y: if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else: else:
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], rlt = np.matmul(
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8: if in_img_type == np.uint8:
rlt = rlt.round() rlt = rlt.round()
else: else:
rlt /= 255. rlt /= 255.0
return rlt.astype(in_img_type) return rlt.astype(in_img_type)
def ycbcr2rgb(img): def ycbcr2rgb(img):
'''same as matlab ycbcr2rgb """same as matlab ycbcr2rgb
Input: Input:
uint8, [0, 255] uint8, [0, 255]
float, [0, 1] float, [0, 1]
''' """
in_img_type = img.dtype in_img_type = img.dtype
img.astype(np.float32) img.astype(np.float32)
if in_img_type != np.uint8: if in_img_type != np.uint8:
img *= 255. img *= 255.0
# convert # convert
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], rlt = np.matmul(
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
) * 255.0 + [-222.921, 135.576, -276.836]
if in_img_type == np.uint8: if in_img_type == np.uint8:
rlt = rlt.round() rlt = rlt.round()
else: else:
rlt /= 255. rlt /= 255.0
return rlt.astype(in_img_type) return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True): def bgr2ycbcr(img, only_y=True):
'''bgr version of rgb2ycbcr """bgr version of rgb2ycbcr
only_y: only return Y channel only_y: only return Y channel
Input: Input:
uint8, [0, 255] uint8, [0, 255]
float, [0, 1] float, [0, 1]
''' """
in_img_type = img.dtype in_img_type = img.dtype
img.astype(np.float32) img.astype(np.float32)
if in_img_type != np.uint8: if in_img_type != np.uint8:
img *= 255. img *= 255.0
# convert # convert
if only_y: if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else: else:
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], rlt = np.matmul(
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8: if in_img_type == np.uint8:
rlt = rlt.round() rlt = rlt.round()
else: else:
rlt /= 255. rlt /= 255.0
return rlt.astype(in_img_type) return rlt.astype(in_img_type)
@ -608,11 +687,11 @@ def channel_convert(in_c, tar_type, img_list):
return img_list return img_list
''' """
# -------------------------------------------- # --------------------------------------------
# metric, PSNR and SSIM # metric, PSNR and SSIM
# -------------------------------------------- # --------------------------------------------
''' """
# -------------------------------------------- # --------------------------------------------
@ -620,17 +699,17 @@ def channel_convert(in_c, tar_type, img_list):
# -------------------------------------------- # --------------------------------------------
def calculate_psnr(img1, img2, border=0): def calculate_psnr(img1, img2, border=0):
# img1 and img2 have range [0, 255] # img1 and img2 have range [0, 255]
#img1 = img1.squeeze() # img1 = img1.squeeze()
#img2 = img2.squeeze() # img2 = img2.squeeze()
if not img1.shape == img2.shape: if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.') raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2] h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border] img1 = img1[border : h - border, border : w - border]
img2 = img2[border:h-border, border:w-border] img2 = img2[border : h - border, border : w - border]
img1 = img1.astype(np.float64) img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64) img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2) mse = np.mean((img1 - img2) ** 2)
if mse == 0: if mse == 0:
return float('inf') return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse)) return 20 * math.log10(255.0 / math.sqrt(mse))
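# Editor's note (not part of the diff): sanity-checking the PSNR formula above.
# For 8-bit images PSNR = 20 * log10(255 / sqrt(MSE)); identical images (MSE == 0)
# return float('inf'), and an everywhere-off-by-one image (MSE == 1) gives
# 20 * log10(255), roughly 48.13 dB:
#
#     a = np.zeros((8, 8), dtype=np.float64)
#     b = np.ones((8, 8), dtype=np.float64)
#     assert abs(calculate_psnr(a, b) - 20 * math.log10(255.0)) < 1e-9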
@ -640,17 +719,17 @@ def calculate_psnr(img1, img2, border=0):
# SSIM # SSIM
# -------------------------------------------- # --------------------------------------------
def calculate_ssim(img1, img2, border=0): def calculate_ssim(img1, img2, border=0):
'''calculate SSIM """calculate SSIM
the same outputs as MATLAB's the same outputs as MATLAB's
img1, img2: [0, 255] img1, img2: [0, 255]
''' """
#img1 = img1.squeeze() # img1 = img1.squeeze()
#img2 = img2.squeeze() # img2 = img2.squeeze()
if not img1.shape == img2.shape: if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.') raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2] h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border] img1 = img1[border : h - border, border : w - border]
img2 = img2[border:h-border, border:w-border] img2 = img2[border : h - border, border : w - border]
if img1.ndim == 2: if img1.ndim == 2:
return ssim(img1, img2) return ssim(img1, img2)
@ -658,7 +737,7 @@ def calculate_ssim(img1, img2, border=0):
if img1.shape[2] == 3: if img1.shape[2] == 3:
ssims = [] ssims = []
for i in range(3): for i in range(3):
ssims.append(ssim(img1[:,:,i], img2[:,:,i])) ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean() return np.array(ssims).mean()
elif img1.shape[2] == 1: elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2)) return ssim(np.squeeze(img1), np.squeeze(img2))
@ -667,8 +746,8 @@ def calculate_ssim(img1, img2, border=0):
def ssim(img1, img2): def ssim(img1, img2):
C1 = (0.01 * 255)**2 C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255)**2 C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64) img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64) img2 = img2.astype(np.float64)
@ -684,16 +763,17 @@ def ssim(img1, img2):
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(sigma1_sq + sigma2_sq + C2)) (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
return ssim_map.mean() return ssim_map.mean()
''' """
# -------------------------------------------- # --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1] # matlab's bicubic imresize (numpy and torch) [0, 1]
# -------------------------------------------- # --------------------------------------------
''' """
# matlab 'imresize' function, now only support 'bicubic' # matlab 'imresize' function, now only support 'bicubic'
@ -701,11 +781,14 @@ def cubic(x):
absx = torch.abs(x) absx = torch.abs(x)
absx2 = absx**2 absx2 = absx**2
absx3 = absx**3 absx3 = absx**3
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): def calculate_weights_indices(
in_length, out_length, scale, kernel, kernel_width, antialiasing
):
if (scale < 1) and (antialiasing): if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale kernel_width = kernel_width / scale
@ -729,8 +812,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width
# The indices of the input pixels involved in computing the k-th output # The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix. # pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
1, P).expand(out_length, P) 0, P - 1, P
).view(1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the # The weights used to compute the k-th output pixel are in row k of the
# weights matrix. # weights matrix.
@ -771,7 +855,11 @@ def imresize(img, scale, antialiasing=True):
if need_squeeze: if need_squeeze:
img.unsqueeze_(0) img.unsqueeze_(0)
in_C, in_H, in_W = img.size() in_C, in_H, in_W = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4 kernel_width = 4
kernel = 'cubic' kernel = 'cubic'
@ -782,9 +870,11 @@ def imresize(img, scale, antialiasing=True):
# get weights and indices # get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing) in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing) in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension # process H dimension
# symmetric copying # symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
@ -805,7 +895,11 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_H): for i in range(out_H):
idx = int(indices_H[i][0]) idx = int(indices_H[i][0])
for j in range(out_C): for j in range(out_C):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) out_1[j, i, :] = (
img_aug[j, idx : idx + kernel_width, :]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension # process W dimension
# symmetric copying # symmetric copying
@ -827,7 +921,9 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_W): for i in range(out_W):
idx = int(indices_W[i][0]) idx = int(indices_W[i][0])
for j in range(out_C): for j in range(out_C):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(
weights_W[i]
)
if need_squeeze: if need_squeeze:
out_2.squeeze_() out_2.squeeze_()
return out_2 return out_2
@ -846,7 +942,11 @@ def imresize_np(img, scale, antialiasing=True):
img.unsqueeze_(2) img.unsqueeze_(2)
in_H, in_W, in_C = img.size() in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4 kernel_width = 4
kernel = 'cubic' kernel = 'cubic'
@ -857,9 +957,11 @@ def imresize_np(img, scale, antialiasing=True):
# get weights and indices # get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing) in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing) in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension # process H dimension
# symmetric copying # symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
@ -880,7 +982,11 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_H): for i in range(out_H):
idx = int(indices_H[i][0]) idx = int(indices_H[i][0])
for j in range(out_C): for j in range(out_C):
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) out_1[i, :, j] = (
img_aug[idx : idx + kernel_width, :, j]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension # process W dimension
# symmetric copying # symmetric copying
@ -902,7 +1008,9 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_W): for i in range(out_W):
idx = int(indices_W[i][0]) idx = int(indices_W[i][0])
for j in range(out_C): for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(
weights_W[i]
)
if need_squeeze: if need_squeeze:
out_2.squeeze_() out_2.squeeze_()
@ -913,4 +1021,4 @@ if __name__ == '__main__':
print('---') print('---')
# img = imread_uint('test.bmp', 3) # img = imread_uint('test.bmp', 3)
# img = uint2single(img) # img = uint2single(img)
# img_bicubic = imresize_np(img, 1/4) # img_bicubic = imresize_np(img, 1/4)
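Editor's note: taken together, the helpers in this file are enough for a quick bicubic baseline. The sketch below is not part of the commit; the module path and the test-image path are assumptions (the latter reused from the __main__ blocks earlier in this diff), and the min-crop guards against the one-pixel size drift that ceil-based resizing can introduce.

import ldm.modules.image_degradation.utils_image as util

sf = 4
hq_uint8 = util.imread_uint('utils/test.png', 3)     # HxWx3 uint8, RGB
hq = util.uint2single(hq_uint8)                      # [0, 1] float32

lq = util.imresize_np(hq, 1 / sf)                    # MATLAB-style bicubic downscale
sr = util.imresize_np(lq, sf)                        # naive bicubic upscale back

# calculate_psnr / calculate_ssim expect [0, 255]-range inputs of identical shape
h = min(sr.shape[0], hq_uint8.shape[0])
w = min(sr.shape[1], hq_uint8.shape[1])
sr_uint8 = util.single2uint(sr)[:h, :w]
psnr = util.calculate_psnr(sr_uint8, hq_uint8[:h, :w])
ssim_val = util.calculate_ssim(sr_uint8, hq_uint8[:h, :w])
print('bicubic x%d baseline: PSNR %.2f dB, SSIM %.4f' % (sf, psnr, ssim_val))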
View File
@ -1 +1 @@
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
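Editor's note: the one-line __init__.py above simply re-exports the loss reformatted in the next file. A hedged construction sketch follows; the keyword names mirror the signature visible in the contperceptual.py hunk below, while the numeric values are illustrative placeholders rather than values from this commit (instantiation also builds the LPIPS perceptual network, which may fetch pretrained weights).

from ldm.modules.losses import LPIPSWithDiscriminator

loss_fn = LPIPSWithDiscriminator(
    disc_start=50001,          # global step at which the GAN term is enabled
    kl_weight=1.0e-6,          # illustrative value, not from this commit
    disc_weight=0.5,           # illustrative value, not from this commit
    disc_in_channels=3,
    disc_loss='hinge',         # must be 'hinge' or 'vanilla' (asserted in __init__)
)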
View File
@ -5,13 +5,24 @@ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/
class LPIPSWithDiscriminator(nn.Module): class LPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, def __init__(
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, self,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_start,
disc_loss="hinge"): logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss='hinge',
):
super().__init__() super().__init__()
assert disc_loss in ["hinge", "vanilla"] assert disc_loss in ['hinge', 'vanilla']
self.kl_weight = kl_weight self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval() self.perceptual_loss = LPIPS().eval()
@ -19,42 +30,68 @@ class LPIPSWithDiscriminator(nn.Module):
# output log variance # output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, self.discriminator = NLayerDiscriminator(
n_layers=disc_num_layers, input_nc=disc_in_channels,
use_actnorm=use_actnorm n_layers=disc_num_layers,
).apply(weights_init) use_actnorm=use_actnorm,
).apply(weights_init)
self.discriminator_iter_start = disc_start self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss self.disc_loss = (
hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss
)
self.disc_factor = disc_factor self.disc_factor = disc_factor
self.discriminator_weight = disc_weight self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None: if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else: else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight d_weight = d_weight * self.discriminator_weight
return d_weight return d_weight
def forward(self, inputs, reconstructions, posteriors, optimizer_idx, def forward(
global_step, last_layer=None, cond=None, split="train", self,
weights=None): inputs,
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
weights=None,
):
rec_loss = torch.abs(
inputs.contiguous() - reconstructions.contiguous()
)
if self.perceptual_weight > 0: if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss weighted_nll_loss = nll_loss
if weights is not None: if weights is not None:
weighted_nll_loss = weights*nll_loss weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] weighted_nll_loss = (
torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
)
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl() kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
@ -67,45 +104,72 @@ class LPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous()) logits_fake = self.discriminator(reconstructions.contiguous())
else: else:
assert self.disc_conditional assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake) g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0: if self.disc_factor > 0.0:
try: try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError: except RuntimeError:
assert not self.training assert not self.training
d_weight = torch.tensor(0.0) d_weight = torch.tensor(0.0)
else: else:
d_weight = torch.tensor(0.0) d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), log = {
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), '{}/total_loss'.format(split): loss.clone().detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(), '{}/logvar'.format(split): self.logvar.detach(),
"{}/d_weight".format(split): d_weight.detach(), '{}/kl_loss'.format(split): kl_loss.detach().mean(),
"{}/disc_factor".format(split): torch.tensor(disc_factor), '{}/nll_loss'.format(split): nll_loss.detach().mean(),
"{}/g_loss".format(split): g_loss.detach().mean(), '{}/rec_loss'.format(split): rec_loss.detach().mean(),
} '{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
}
return loss, log return loss, log
if optimizer_idx == 1: if optimizer_idx == 1:
# second pass for discriminator update # second pass for discriminator update
if cond is None: if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach()) logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else: else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) logits_real = self.discriminator(
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), log = {
"{}/logits_real".format(split): logits_real.detach().mean(), '{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean() '{}/logits_real'.format(split): logits_real.detach().mean(),
} '{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log return d_loss, log
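
The reformatting above does not change how LPIPSWithDiscriminator is driven: the same forward is called twice per training step, once with optimizer_idx=0 for the autoencoder update and once with optimizer_idx=1 for the discriminator. A minimal sketch of that calling pattern follows; the tensors are random stand-ins, the DiagonalGaussianDistribution import path is an assumption (it is not part of this diff), and loss_fn.eval() is used only so the toy graph, where last_layer is not connected to the loss, falls through the guarded adaptive-weight branch instead of asserting:

import torch
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution  # assumed path

loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)
loss_fn.eval()   # toy setting only; during real training this module stays in train mode

inputs = torch.randn(2, 3, 64, 64)           # target images in [-1, 1]
reconstructions = torch.randn(2, 3, 64, 64)  # decoder output (stand-in)
posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))  # mean/logvar stacked on dim 1
last_layer = torch.nn.Parameter(torch.randn(3, 4, 3, 3))           # e.g. the decoder's final conv weight

# pass 0: NLL + KL + adaptively weighted generator (GAN) term
ae_loss, ae_log = loss_fn(inputs, reconstructions, posterior, 0, global_step=0,
                          last_layer=last_layer, split='train')

# pass 1: hinge (or vanilla) discriminator loss on real vs. reconstructed logits
d_loss, d_log = loss_fn(inputs, reconstructions, posterior, 1, global_step=0, split='train')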
View File
@ -3,21 +3,25 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import repeat from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init from taming.modules.discriminator.model import (
NLayerDiscriminator,
weights_init,
)
from taming.modules.losses.lpips import LPIPS from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum() loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake) d_loss = 0.5 * (loss_real + loss_fake)
return d_loss return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold: if global_step < threshold:
weight = value weight = value
return weight return weight
@ -26,57 +30,76 @@ def adopt_weight(weight, global_step, threshold=0, value=0.):
def measure_perplexity(predicted_indices, n_embed): def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) encodings = (
F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
)
avg_probs = encodings.mean(0) avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0) cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use return perplexity, cluster_use
def l1(x, y): def l1(x, y):
return torch.abs(x-y) return torch.abs(x - y)
def l2(x, y): def l2(x, y):
return torch.pow((x-y), 2) return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module): class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, def __init__(
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, self,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_start,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", codebook_weight=1.0,
pixel_loss="l1"): pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss='hinge',
n_classes=None,
perceptual_loss='lpips',
pixel_loss='l1',
):
super().__init__() super().__init__()
assert disc_loss in ["hinge", "vanilla"] assert disc_loss in ['hinge', 'vanilla']
assert perceptual_loss in ["lpips", "clips", "dists"] assert perceptual_loss in ['lpips', 'clips', 'dists']
assert pixel_loss in ["l1", "l2"] assert pixel_loss in ['l1', 'l2']
self.codebook_weight = codebook_weight self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips": if perceptual_loss == 'lpips':
print(f"{self.__class__.__name__}: Running with LPIPS.") print(f'{self.__class__.__name__}: Running with LPIPS.')
self.perceptual_loss = LPIPS().eval() self.perceptual_loss = LPIPS().eval()
else: else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") raise ValueError(
f'Unknown perceptual loss: >> {perceptual_loss} <<'
)
self.perceptual_weight = perceptual_weight self.perceptual_weight = perceptual_weight
if pixel_loss == "l1": if pixel_loss == 'l1':
self.pixel_loss = l1 self.pixel_loss = l1
else: else:
self.pixel_loss = l2 self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, self.discriminator = NLayerDiscriminator(
n_layers=disc_num_layers, input_nc=disc_in_channels,
use_actnorm=use_actnorm, n_layers=disc_num_layers,
ndf=disc_ndf use_actnorm=use_actnorm,
).apply(weights_init) ndf=disc_ndf,
).apply(weights_init)
self.discriminator_iter_start = disc_start self.discriminator_iter_start = disc_start
if disc_loss == "hinge": if disc_loss == 'hinge':
self.disc_loss = hinge_d_loss self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla": elif disc_loss == 'vanilla':
self.disc_loss = vanilla_d_loss self.disc_loss = vanilla_d_loss
else: else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.") raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.')
self.disc_factor = disc_factor self.disc_factor = disc_factor
self.discriminator_weight = disc_weight self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional self.disc_conditional = disc_conditional
@ -84,31 +107,53 @@ class VQLPIPSWithDiscriminator(nn.Module):
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None: if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else: else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] nll_grads = torch.autograd.grad(
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight d_weight = d_weight * self.discriminator_weight
return d_weight return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, def forward(
global_step, last_layer=None, cond=None, split="train", predicted_indices=None): self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
predicted_indices=None,
):
if not exists(codebook_loss): if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device) codebook_loss = torch.tensor([0.0]).to(inputs.device)
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) rec_loss = self.pixel_loss(
inputs.contiguous(), reconstructions.contiguous()
)
if self.perceptual_weight > 0: if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss rec_loss = rec_loss + self.perceptual_weight * p_loss
else: else:
p_loss = torch.tensor([0.0]) p_loss = torch.tensor([0.0])
nll_loss = rec_loss nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss) nll_loss = torch.mean(nll_loss)
# now the GAN part # now the GAN part
@ -119,49 +164,77 @@ class VQLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous()) logits_fake = self.discriminator(reconstructions.contiguous())
else: else:
assert self.disc_conditional assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake) g_loss = -torch.mean(logits_fake)
try: try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError: except RuntimeError:
assert not self.training assert not self.training
d_weight = torch.tensor(0.0) d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), log = {
"{}/quant_loss".format(split): codebook_loss.detach().mean(), '{}/total_loss'.format(split): loss.clone().detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(), '{}/quant_loss'.format(split): codebook_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(), '{}/nll_loss'.format(split): nll_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(), '{}/rec_loss'.format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(), '{}/p_loss'.format(split): p_loss.detach().mean(),
"{}/disc_factor".format(split): torch.tensor(disc_factor), '{}/d_weight'.format(split): d_weight.detach(),
"{}/g_loss".format(split): g_loss.detach().mean(), '{}/disc_factor'.format(split): torch.tensor(disc_factor),
} '{}/g_loss'.format(split): g_loss.detach().mean(),
}
if predicted_indices is not None: if predicted_indices is not None:
assert self.n_classes is not None assert self.n_classes is not None
with torch.no_grad(): with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) perplexity, cluster_usage = measure_perplexity(
log[f"{split}/perplexity"] = perplexity predicted_indices, self.n_classes
log[f"{split}/cluster_usage"] = cluster_usage )
log[f'{split}/perplexity'] = perplexity
log[f'{split}/cluster_usage'] = cluster_usage
return loss, log return loss, log
if optimizer_idx == 1: if optimizer_idx == 1:
# second pass for discriminator update # second pass for discriminator update
if cond is None: if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach()) logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else: else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) logits_real = self.discriminator(
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), log = {
"{}/logits_real".format(split): logits_real.detach().mean(), '{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean() '{}/logits_real'.format(split): logits_real.detach().mean(),
} '{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log return d_loss, log
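
One small helper reformatted in this file, measure_perplexity, is easy to sanity-check numerically: when every codebook entry is used equally often the perplexity equals n_embed, and under codebook collapse it approaches 1. A short check, assuming this file is importable as ldm.modules.losses.vqperceptual (path assumed, not shown in the diff):

import torch
from ldm.modules.losses.vqperceptual import measure_perplexity  # assumed module path

n_embed = 8
uniform = torch.arange(n_embed).repeat(100)     # every code used equally often
collapsed = torch.zeros(800, dtype=torch.long)  # codebook collapse: a single code

print(measure_perplexity(uniform, n_embed))    # perplexity ~ 8.0, cluster_use = 8
print(measure_perplexity(collapsed, n_embed))  # perplexity ~ 1.0, cluster_use = 1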
View File
@ -11,15 +11,13 @@ from einops import rearrange, repeat, reduce
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [ Intermediates = namedtuple(
'pre_softmax_attn', 'Intermediates', ['pre_softmax_attn', 'post_softmax_attn']
'post_softmax_attn' )
])
LayerIntermediates = namedtuple('Intermediates', [ LayerIntermediates = namedtuple(
'hiddens', 'Intermediates', ['hiddens', 'attn_intermediates']
'attn_intermediates' )
])
class AbsolutePositionalEmbedding(nn.Module): class AbsolutePositionalEmbedding(nn.Module):
@ -39,11 +37,16 @@ class AbsolutePositionalEmbedding(nn.Module):
class FixedPositionalEmbedding(nn.Module): class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq) self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0): def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset t = (
torch.arange(x.shape[seq_dim], device=x.device).type_as(
self.inv_freq
)
+ offset
)
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :] return emb[None, :, :]
@ -51,6 +54,7 @@ class FixedPositionalEmbedding(nn.Module):
# helpers # helpers
def exists(val): def exists(val):
return val is not None return val is not None
@ -64,18 +68,21 @@ def default(val, d):
def always(val): def always(val):
def inner(*args, **kwargs): def inner(*args, **kwargs):
return val return val
return inner return inner
def not_equals(val): def not_equals(val):
def inner(x): def inner(x):
return x != val return x != val
return inner return inner
def equals(val): def equals(val):
def inner(x): def inner(x):
return x == val return x == val
return inner return inner
@ -85,6 +92,7 @@ def max_neg_value(tensor):
# keyword argument helpers # keyword argument helpers
def pick_and_pop(keys, d): def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys)) values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values)) return dict(zip(keys, values))
@ -108,8 +116,15 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d): def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) kwargs_with_prefix, kwargs = group_dict_by_key(
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(
lambda x: (x[0][len(prefix) :], x[1]),
tuple(kwargs_with_prefix.items()),
)
)
return kwargs_without_prefix, kwargs return kwargs_without_prefix, kwargs
@ -139,7 +154,7 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module): class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5): def __init__(self, dim, eps=1e-5):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(1)) self.g = nn.Parameter(torch.ones(1))
@ -151,7 +166,7 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8): def __init__(self, dim, eps=1e-8):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(dim)) self.g = nn.Parameter(torch.ones(dim))
@ -173,7 +188,7 @@ class GRUGating(nn.Module):
def forward(self, x, residual): def forward(self, x, residual):
gated_output = self.gru( gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'), rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d') rearrange(residual, 'b n d -> (b n) d'),
) )
return gated_output.reshape_as(x) return gated_output.reshape_as(x)
@ -181,6 +196,7 @@ class GRUGating(nn.Module):
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super().__init__() super().__init__()
@ -192,19 +208,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = (
nn.Linear(dim, inner_dim), nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
nn.GELU() if not glu
) if not glu else GEGLU(dim, inner_dim) else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
) )
def forward(self, x): def forward(self, x):
@ -214,23 +229,25 @@ class FeedForward(nn.Module):
# attention. # attention.
class Attention(nn.Module): class Attention(nn.Module):
def __init__( def __init__(
self, self,
dim, dim,
dim_head=DEFAULT_DIM_HEAD, dim_head=DEFAULT_DIM_HEAD,
heads=8, heads=8,
causal=False, causal=False,
mask=None, mask=None,
talking_heads=False, talking_heads=False,
sparse_topk=None, sparse_topk=None,
use_entmax15=False, use_entmax15=False,
num_mem_kv=0, num_mem_kv=0,
dropout=0., dropout=0.0,
on_attn=False on_attn=False,
): ):
super().__init__() super().__init__()
if use_entmax15: if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!") raise NotImplementedError(
self.scale = dim_head ** -0.5 'Check out entmax activation instead of softmax activation!'
)
self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.causal = causal self.causal = causal
self.mask = mask self.mask = mask
@ -252,7 +269,7 @@ class Attention(nn.Module):
self.sparse_topk = sparse_topk self.sparse_topk = sparse_topk
# entmax # entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax # self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax self.attn_fn = F.softmax
# add memory key / values # add memory key / values
@ -263,20 +280,29 @@ class Attention(nn.Module):
# attention on attention # attention on attention
self.attn_on_attn = on_attn self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(inner_dim, dim)
)
def forward( def forward(
self, self,
x, x,
context=None, context=None,
mask=None, mask=None,
context_mask=None, context_mask=None,
rel_pos=None, rel_pos=None,
sinusoidal_emb=None, sinusoidal_emb=None,
prev_attn=None, prev_attn=None,
mem=None mem=None,
): ):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device b, n, _, h, talking_heads, device = (
*x.shape,
self.heads,
self.talking_heads,
x.device,
)
kv_input = default(context, x) kv_input = default(context, x)
q_input = x q_input = x
@ -297,23 +323,35 @@ class Attention(nn.Module):
k = self.to_k(k_input) k = self.to_k(k_input)
v = self.to_v(v_input) v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)
)
input_mask = None input_mask = None
if any(map(exists, (mask, context_mask))): if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) q_mask = default(
mask, lambda: torch.ones((b, n), device=device).bool()
)
k_mask = q_mask if not exists(context) else context_mask k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) k_mask = default(
k_mask,
lambda: torch.ones((b, k.shape[-2]), device=device).bool(),
)
q_mask = rearrange(q_mask, 'b i -> b () i ()') q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j') k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask input_mask = q_mask * k_mask
if self.num_mem_kv > 0: if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) mem_k, mem_v = map(
lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v),
)
k = torch.cat((mem_k, k), dim=-2) k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2) v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask): if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True
)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots) mask_value = max_neg_value(dots)
@ -324,7 +362,9 @@ class Attention(nn.Module):
pre_softmax_attn = dots pre_softmax_attn = dots
if talking_heads: if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() dots = einsum(
'b h i j, h k -> b k i j', dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos): if exists(rel_pos):
dots = rel_pos(dots) dots = rel_pos(dots)
@ -336,7 +376,9 @@ class Attention(nn.Module):
if self.causal: if self.causal:
i, j = dots.shape[-2:] i, j = dots.shape[-2:]
r = torch.arange(i, device=device) r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j'
)
mask = F.pad(mask, (j - i, 0), value=False) mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value) dots.masked_fill_(mask, mask_value)
del mask del mask
@ -354,14 +396,16 @@ class Attention(nn.Module):
attn = self.dropout(attn) attn = self.dropout(attn)
if talking_heads: if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() attn = einsum(
'b h i j, h k -> b k i j', attn, self.post_softmax_proj
).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v) out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates( intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn, pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn post_softmax_attn=post_softmax_attn,
) )
return self.to_out(out), intermediates return self.to_out(out), intermediates
@ -369,28 +413,28 @@ class Attention(nn.Module):
class AttentionLayers(nn.Module): class AttentionLayers(nn.Module):
def __init__( def __init__(
self, self,
dim, dim,
depth, depth,
heads=8, heads=8,
causal=False, causal=False,
cross_attend=False, cross_attend=False,
only_cross=False, only_cross=False,
use_scalenorm=False, use_scalenorm=False,
use_rmsnorm=False, use_rmsnorm=False,
use_rezero=False, use_rezero=False,
rel_pos_num_buckets=32, rel_pos_num_buckets=32,
rel_pos_max_distance=128, rel_pos_max_distance=128,
position_infused_attn=False, position_infused_attn=False,
custom_layers=None, custom_layers=None,
sandwich_coef=None, sandwich_coef=None,
par_ratio=None, par_ratio=None,
residual_attn=False, residual_attn=False,
cross_residual_attn=False, cross_residual_attn=False,
macaron=False, macaron=False,
pre_norm=True, pre_norm=True,
gate_residual=False, gate_residual=False,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
@ -403,10 +447,14 @@ class AttentionLayers(nn.Module):
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
self.rotary_pos_emb = always(None) self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' assert (
rel_pos_num_buckets <= rel_pos_max_distance
), 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None self.rel_pos = None
self.pre_norm = pre_norm self.pre_norm = pre_norm
@ -438,15 +486,27 @@ class AttentionLayers(nn.Module):
assert 1 < par_ratio <= par_depth, 'par ratio out of range' assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block)) default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper depth_cut = (
par_depth * 2 // 3
) # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio' assert (
par_block = default_block + ('f',) * (par_width - len(default_block)) len(default_block) <= par_width
), 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (
par_width - len(default_block)
)
par_head = par_block * par_attn par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head)) layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef): elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' assert (
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef sandwich_coef > 0 and sandwich_coef <= depth
), 'sandwich coefficient should be less than the depth'
layer_types = (
('a',) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ('f',) * sandwich_coef
)
else: else:
layer_types = default_block * depth layer_types = default_block * depth
@ -455,7 +515,9 @@ class AttentionLayers(nn.Module):
for layer_type in self.layer_types: for layer_type in self.layer_types:
if layer_type == 'a': if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs
)
elif layer_type == 'c': elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs) layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f': elif layer_type == 'f':
@ -472,21 +534,17 @@ class AttentionLayers(nn.Module):
else: else:
residual_fn = Residual() residual_fn = Residual()
self.layers.append(nn.ModuleList([ self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
norm_fn(),
layer,
residual_fn
]))
def forward( def forward(
self, self,
x, x,
context=None, context=None,
mask=None, mask=None,
context_mask=None, context_mask=None,
mems=None, mems=None,
return_hiddens=False, return_hiddens=False,
**kwargs **kwargs,
): ):
hiddens = [] hiddens = []
intermediates = [] intermediates = []
@ -495,7 +553,9 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1) is_last = ind == (len(self.layers) - 1)
if layer_type == 'a': if layer_type == 'a':
@ -508,10 +568,22 @@ class AttentionLayers(nn.Module):
x = norm(x) x = norm(x)
if layer_type == 'a': if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, out, inter = block(
prev_attn=prev_attn, mem=layer_mem) x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == 'c': elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == 'f': elif layer_type == 'f':
out = block(x) out = block(x)
@ -530,8 +602,7 @@ class AttentionLayers(nn.Module):
if return_hiddens: if return_hiddens:
intermediates = LayerIntermediates( intermediates = LayerIntermediates(
hiddens=hiddens, hiddens=hiddens, attn_intermediates=intermediates
attn_intermediates=intermediates
) )
return x, intermediates return x, intermediates
@ -545,23 +616,24 @@ class Encoder(AttentionLayers):
super().__init__(causal=False, **kwargs) super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module): class TransformerWrapper(nn.Module):
def __init__( def __init__(
self, self,
*, *,
num_tokens, num_tokens,
max_seq_len, max_seq_len,
attn_layers, attn_layers,
emb_dim=None, emb_dim=None,
max_mem_len=0., max_mem_len=0.0,
emb_dropout=0., emb_dropout=0.0,
num_memory_tokens=None, num_memory_tokens=None,
tie_embedding=False, tie_embedding=False,
use_pos_emb=True use_pos_emb=True,
): ):
super().__init__() super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim dim = attn_layers.dim
emb_dim = default(emb_dim, dim) emb_dim = default(emb_dim, dim)
@ -571,23 +643,34 @@ class TransformerWrapper(nn.Module):
self.num_tokens = num_tokens self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim) self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( self.pos_emb = (
use_pos_emb and not attn_layers.has_pos_emb) else always(0) AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout) self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() self.project_emb = (
nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
)
self.attn_layers = attn_layers self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
self.init_() self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
# memory tokens (like [cls]) from Memory Transformers paper # memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0) num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0: if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim)
)
# let funnel encoder know number of memory tokens, if specified # let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'): if hasattr(attn_layers, 'num_memory_tokens'):
@ -597,20 +680,20 @@ class TransformerWrapper(nn.Module):
nn.init.normal_(self.token_emb.weight, std=0.02) nn.init.normal_(self.token_emb.weight, std=0.02)
def forward( def forward(
self, self,
x, x,
return_embeddings=False, return_embeddings=False,
mask=None, mask=None,
return_mems=False, return_mems=False,
return_attn=False, return_attn=False,
mems=None, mems=None,
embedding_manager=None, embedding_manager=None,
**kwargs **kwargs,
): ):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
embedded_x = self.token_emb(x) embedded_x = self.token_emb(x)
if embedding_manager: if embedding_manager:
x = embedding_manager(x, embedded_x) x = embedding_manager(x, embedded_x)
else: else:
@ -629,7 +712,9 @@ class TransformerWrapper(nn.Module):
if exists(mask): if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True) mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
)
x = self.norm(x) x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:] mem, x = x[:, :num_mem], x[:, num_mem:]
@ -638,13 +723,30 @@ class TransformerWrapper(nn.Module):
if return_mems: if return_mems:
hiddens = intermediates.hiddens hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens new_mems = (
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) list(
map(
lambda pair: torch.cat(pair, dim=-2),
zip(mems, hiddens),
)
)
if exists(mems)
else hiddens
)
new_mems = list(
map(
lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems
)
)
return out, new_mems return out, new_mems
if return_attn: if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) attn_maps = list(
map(
lambda t: t.post_softmax_attn,
intermediates.attn_intermediates,
)
)
return out, attn_maps return out, attn_maps
return out return out
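
For orientation, the reflowed TransformerWrapper keeps its original constructor and forward signature, so the usual construction around an Encoder stack still applies. A minimal sketch, assuming this vendored module is importable as ldm.modules.x_transformer (an assumed path) and using illustrative sizes:

import torch
from ldm.modules.x_transformer import Encoder, TransformerWrapper  # assumed module path

model = TransformerWrapper(
    num_tokens=30522,                                # vocabulary size (illustrative)
    max_seq_len=77,
    attn_layers=Encoder(dim=512, depth=4, heads=8),  # non-causal AttentionLayers
)

tokens = torch.randint(0, 30522, (2, 77))
embeddings = model(tokens, return_embeddings=True)   # (2, 77, 512) hidden states, e.g. for conditioning
logits = model(tokens)                               # (2, 77, 30522) token logits via to_logits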
View File
@ -24,10 +24,10 @@ import re
import traceback import traceback
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter
"""Simplified text to image API for stable diffusion/latent diffusion """Simplified text to image API for stable diffusion/latent diffusion
@ -93,67 +93,69 @@ still work.
class T2I: class T2I:
"""T2I class """T2I class
Attributes Attributes
---------- ----------
model model
config config
iterations iterations
batch_size batch_size
steps steps
seed seed
sampler_name sampler_name
width width
height height
cfg_scale cfg_scale
latent_channels latent_channels
downsampling_factor downsampling_factor
precision precision
strength strength
embedding_path embedding_path
The vast majority of these arguments default to reasonable values. The vast majority of these arguments default to reasonable values.
""" """
def __init__(self,
batch_size=1, def __init__(
iterations = 1, self,
steps=50, batch_size=1,
seed=None, iterations=1,
cfg_scale=7.5, steps=50,
weights="models/ldm/stable-diffusion-v1/model.ckpt", seed=None,
config = "configs/stable-diffusion/v1-inference.yaml", cfg_scale=7.5,
width=512, weights='models/ldm/stable-diffusion-v1/model.ckpt',
height=512, config='configs/stable-diffusion/v1-inference.yaml',
sampler_name="klms", width=512,
latent_channels=4, height=512,
downsampling_factor=8, sampler_name='klms',
ddim_eta=0.0, # deterministic latent_channels=4,
precision='autocast', downsampling_factor=8,
full_precision=False, ddim_eta=0.0, # deterministic
strength=0.75, # default in scripts/img2img.py precision='autocast',
embedding_path=None, full_precision=False,
latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt strength=0.75, # default in scripts/img2img.py
device='cuda', embedding_path=None,
gfpgan=None, latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt
device='cuda',
gfpgan=None,
): ):
self.batch_size = batch_size self.batch_size = batch_size
self.iterations = iterations self.iterations = iterations
self.width = width self.width = width
self.height = height self.height = height
self.steps = steps self.steps = steps
self.cfg_scale = cfg_scale self.cfg_scale = cfg_scale
self.weights = weights self.weights = weights
self.config = config self.config = config
self.sampler_name = sampler_name self.sampler_name = sampler_name
self.latent_channels = latent_channels self.latent_channels = latent_channels
self.downsampling_factor = downsampling_factor self.downsampling_factor = downsampling_factor
self.ddim_eta = ddim_eta self.ddim_eta = ddim_eta
self.precision = precision self.precision = precision
self.full_precision = full_precision self.full_precision = full_precision
self.strength = strength self.strength = strength
self.embedding_path = embedding_path self.embedding_path = embedding_path
self.model = None # empty for now self.model = None # empty for now
self.sampler = None self.sampler = None
self.latent_diffusion_weights=latent_diffusion_weights self.latent_diffusion_weights = latent_diffusion_weights
self.device = device self.device = device
self.gfpgan = gfpgan self.gfpgan = gfpgan
if seed is None: if seed is None:
@ -162,49 +164,55 @@ The vast majority of these arguments default to reasonable values.
self.seed = seed self.seed = seed
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
def prompt2png(self,prompt,outdir,**kwargs): def prompt2png(self, prompt, outdir, **kwargs):
''' """
Takes a prompt and an output directory, writes out the requested number Takes a prompt and an output directory, writes out the requested number
of PNG files, and returns an array of [[filename,seed],[filename,seed]...] of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
Optional named arguments are the same as those passed to T2I and prompt2image() Optional named arguments are the same as those passed to T2I and prompt2image()
''' """
results = self.prompt2image(prompt,**kwargs) results = self.prompt2image(prompt, **kwargs)
pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size)) pngwriter = PngWriter(
outdir, prompt, kwargs.get('batch_size', self.batch_size)
)
for r in results: for r in results:
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
pngwriter.write_image(r[0],r[1]) pngwriter.write_image(r[0], r[1])
return pngwriter.files_written return pngwriter.files_written
def txt2img(self,prompt,**kwargs): def txt2img(self, prompt, **kwargs):
outdir = kwargs.get('outdir','outputs/img-samples') outdir = kwargs.get('outdir', 'outputs/img-samples')
return self.prompt2png(prompt,outdir,**kwargs) return self.prompt2png(prompt, outdir, **kwargs)
def img2img(self,prompt,**kwargs): def img2img(self, prompt, **kwargs):
outdir = kwargs.get('outdir','outputs/img-samples') outdir = kwargs.get('outdir', 'outputs/img-samples')
assert 'init_img' in kwargs,'call to img2img() must include the init_img argument' assert (
return self.prompt2png(prompt,outdir,**kwargs) 'init_img' in kwargs
), 'call to img2img() must include the init_img argument'
return self.prompt2png(prompt, outdir, **kwargs)
def prompt2image(self, def prompt2image(
# these are common self,
prompt, # these are common
batch_size=None, prompt,
iterations=None, batch_size=None,
steps=None, iterations=None,
seed=None, steps=None,
cfg_scale=None, seed=None,
ddim_eta=None, cfg_scale=None,
skip_normalize=False, ddim_eta=None,
image_callback=None, skip_normalize=False,
# these are specific to txt2img image_callback=None,
width=None, # these are specific to txt2img
height=None, width=None,
# these are specific to img2img height=None,
init_img=None, # these are specific to img2img
strength=None, init_img=None,
gfpgan_strength=None, strength=None,
variants=None, gfpgan_strength=None,
**args): # eat up additional cruft variants=None,
''' **args,
): # eat up additional cruft
"""
ldm.prompt2image() is the common entry point for txt2img() and img2img() ldm.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments: It takes the following arguments:
prompt // prompt string (no default) prompt // prompt string (no default)
@ -232,118 +240,157 @@ The vast majority of these arguments default to reasonable values.
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
to create the requested output directory, select a unique informative name for each image, and to create the requested output directory, select a unique informative name for each image, and
write the prompt into the PNG metadata. write the prompt into the PNG metadata.
''' """
steps = steps or self.steps steps = steps or self.steps
seed = seed or self.seed seed = seed or self.seed
width = width or self.width width = width or self.width
height = height or self.height height = height or self.height
cfg_scale = cfg_scale or self.cfg_scale cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta ddim_eta = ddim_eta or self.ddim_eta
batch_size = batch_size or self.batch_size batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength strength = strength or self.strength
model = self.load_model() # will instantiate the model or return it from cache model = (
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0" self.load_model()
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]' ) # will instantiate the model or return it from cache
w = int(width/64) * 64 assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
h = int(height/64) * 64 assert (
0.0 <= strength <= 1.0
), 'can only work with strength in [0.0, 1.0]'
w = int(width / 64) * 64
h = int(height / 64) * 64
if h != height or w != width: if h != height or w != width:
print(f'Height and width must be multiples of 64. Resizing to {h}x{w}') print(
f'Height and width must be multiples of 64. Resizing to {h}x{w}'
)
height = h height = h
width = w width = w
scope = autocast if self.precision=="autocast" else nullcontext scope = autocast if self.precision == 'autocast' else nullcontext
tic = time.time() tic = time.time()
results = list() results = list()
try: try:
if init_img: if init_img:
assert os.path.exists(init_img),f'{init_img}: File not found' assert os.path.exists(init_img), f'{init_img}: File not found'
images_iterator = self._img2img(prompt, images_iterator = self._img2img(
precision_scope=scope, prompt,
batch_size=batch_size, precision_scope=scope,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, batch_size=batch_size,
skip_normalize=skip_normalize, steps=steps,
init_img=init_img,strength=strength) cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
init_img=init_img,
strength=strength,
)
else: else:
images_iterator = self._txt2img(prompt, images_iterator = self._txt2img(
precision_scope=scope, prompt,
batch_size=batch_size, precision_scope=scope,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, batch_size=batch_size,
skip_normalize=skip_normalize, steps=steps,
width=width,height=height) cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
width=width,
height=height,
)
with scope(self.device.type), self.model.ema_scope(): with scope(self.device.type), self.model.ema_scope():
for n in trange(iterations, desc="Sampling"): for n in trange(iterations, desc='Sampling'):
seed_everything(seed) seed_everything(seed)
iter_images = next(images_iterator) iter_images = next(images_iterator)
for image in iter_images: for image in iter_images:
try: try:
if gfpgan_strength > 0: if gfpgan_strength > 0:
image = self._run_gfpgan(image, gfpgan_strength) image = self._run_gfpgan(
image, gfpgan_strength
)
except Exception as e: except Exception as e:
print(f"Error running GFPGAN - Your image was not enhanced.\n{e}") print(
f'Error running GFPGAN - Your image was not enhanced.\n{e}'
)
results.append([image, seed]) results.append([image, seed])
if image_callback is not None: if image_callback is not None:
image_callback(image,seed) image_callback(image, seed)
seed = self._new_seed() seed = self._new_seed()
except KeyboardInterrupt: except KeyboardInterrupt:
print('*interrupted*') print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.') print(
'Partial results will be returned; if --grid was requested, nothing will be returned.'
)
except RuntimeError as e: except RuntimeError as e:
print(str(e)) print(str(e))
print('Are you sure your system has an adequate NVIDIA GPU?') print('Are you sure your system has an adequate NVIDIA GPU?')
toc = time.time() toc = time.time()
print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic)) print(f'{len(results)} images generated in', '%4.2fs' % (toc - tic))
return results return results
@torch.no_grad() @torch.no_grad()
def _txt2img(self, def _txt2img(
prompt, self,
precision_scope, prompt,
batch_size, precision_scope,
steps,cfg_scale,ddim_eta, batch_size,
skip_normalize, steps,
width,height): cfg_scale,
ddim_eta,
skip_normalize,
width,
height,
):
""" """
An infinite iterator of images from the prompt. An infinite iterator of images from the prompt.
""" """
sampler = self.sampler sampler = self.sampler
while True: while True:
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] shape = [
samples, _ = sampler.sample(S=steps, self.latent_channels,
conditioning=c, height // self.downsampling_factor,
batch_size=batch_size, width // self.downsampling_factor,
shape=shape, ]
verbose=False, samples, _ = sampler.sample(
unconditional_guidance_scale=cfg_scale, S=steps,
unconditional_conditioning=uc, conditioning=c,
eta=ddim_eta) batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
)
yield self._samples_to_images(samples) yield self._samples_to_images(samples)
@torch.no_grad() @torch.no_grad()
def _img2img(self, def _img2img(
prompt, self,
precision_scope, prompt,
batch_size, precision_scope,
steps,cfg_scale,ddim_eta, batch_size,
skip_normalize, steps,
init_img,strength): cfg_scale,
ddim_eta,
skip_normalize,
init_img,
strength,
):
""" """
An infinite iterator of images from the prompt and the initial image An infinite iterator of images from the prompt and the initial image
""" """
# PLMS sampler not supported yet, so ignore previous sampler # PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim': if self.sampler_name != 'ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler") print(
f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
)
sampler = DDIMSampler(self.model, device=self.device) sampler = DDIMSampler(self.model, device=self.device)
else: else:
sampler = self.sampler sampler = self.sampler
@ -351,9 +398,13 @@ The vast majority of these arguments default to reasonable values.
init_image = self._load_img(init_img).to(self.device) init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
with precision_scope(self.device.type): with precision_scope(self.device.type):
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
t_enc = int(strength * steps) t_enc = int(strength * steps)
# print(f"target t_enc is {t_enc} steps") # print(f"target t_enc is {t_enc} steps")
@ -362,31 +413,44 @@ The vast majority of these arguments default to reasonable values.
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
# encode (scaled latent) # encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc] * batch_size).to(self.device)
)
# decode it # decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, samples = sampler.decode(
unconditional_conditioning=uc,) z_enc,
c,
t_enc,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
)
yield self._samples_to_images(samples) yield self._samples_to_images(samples)
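Editor's note: for intuition, strength only decides how far into the schedule the init latent is noised before being denoised back, via t_enc = int(strength * steps). A quick arithmetic check of that formula:

# How img2img strength maps onto DDIM steps (same formula as above).
steps = 50
for strength in (0.3, 0.75, 1.0):
    t_enc = int(strength * steps)
    print(f'strength {strength}: encode to step {t_enc}, then run {t_enc} decode steps')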
# TODO: does this actually need to run every loop? does anything in it vary by random seed? # TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, batch_size, skip_normalize): def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
uc = self.model.get_learned_conditioning(batch_size * [""]) uc = self.model.get_learned_conditioning(batch_size * [''])
# weighted sub-prompts # weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompt) subprompts, weights = T2I._split_weighted_subprompts(prompt)
if len(subprompts) > 1: if len(subprompts) > 1:
# i dont know if this is correct.. but it works # i dont know if this is correct.. but it works
c = torch.zeros_like(uc) c = torch.zeros_like(uc)
# get total weight for normalizing # get total weight for normalizing
totalWeight = sum(weights) totalWeight = sum(weights)
# normalize each "sub prompt" and add it # normalize each "sub prompt" and add it
for i in range(0,len(subprompts)): for i in range(0, len(subprompts)):
weight = weights[i] weight = weights[i]
if not skip_normalize: if not skip_normalize:
weight = weight / totalWeight weight = weight / totalWeight
c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight) c = torch.add(
else: # just standard 1 prompt c,
self.model.get_learned_conditioning(
batch_size * [subprompts[i]]
),
alpha=weight,
)
else: # just standard 1 prompt
c = self.model.get_learned_conditioning(batch_size * [prompt]) c = self.model.get_learned_conditioning(batch_size * [prompt])
return (uc, c) return (uc, c)
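Editor's note: unless skip_normalize is set, the loop above builds a convex combination of the per-sub-prompt embeddings. The same arithmetic on stand-in tensors, with no model involved (the shape is illustrative only):

import torch

# Stand-ins for get_learned_conditioning(batch_size * [sub_prompt]).
embeddings = [torch.randn(1, 77, 768) for _ in range(3)]
weights = [2.0, 1.0, 1.0]
total_weight = sum(weights)

c = torch.zeros_like(embeddings[0])
for emb, weight in zip(embeddings, weights):
    c = torch.add(c, emb, alpha=weight / total_weight)   # normalized weight

print(c.shape)   # torch.Size([1, 77, 768])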
@@ -395,23 +459,29 @@ The vast majority of these arguments default to reasonable values.
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
images = list() images = list()
for x_sample in x_samples: for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255.0 * rearrange(
x_sample.cpu().numpy(), 'c h w -> h w c'
)
image = Image.fromarray(x_sample.astype(np.uint8)) image = Image.fromarray(x_sample.astype(np.uint8))
images.append(image) images.append(image)
return images return images
def _new_seed(self): def _new_seed(self):
self.seed = random.randrange(0,np.iinfo(np.uint32).max) self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed return self.seed
def load_model(self): def load_model(self):
""" Load and initialize the model from configuration variables passed at object creation time """ """Load and initialize the model from configuration variables passed at object creation time"""
if self.model is None: if self.model is None:
seed_everything(self.seed) seed_everything(self.seed)
try: try:
config = OmegaConf.load(self.config) config = OmegaConf.load(self.config)
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu") self.device = (
model = self._load_model_from_config(config,self.weights) torch.device(self.device)
if torch.cuda.is_available()
else torch.device('cpu')
)
model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None: if self.embedding_path is not None:
model.embedding_manager.load(self.embedding_path) model.embedding_manager.load(self.embedding_path)
self.model = model.to(self.device) self.model = model.to(self.device)
@@ -421,18 +491,26 @@ The vast majority of these arguments default to reasonable values.
raise SystemExit raise SystemExit
msg = f'setting sampler to {self.sampler_name}' msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name=='plms': if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device) self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim': elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device) self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a': elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device) self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
elif self.sampler_name == 'k_dpm_2': elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device) self.sampler = KSampler(
self.model, 'dpm_2', device=self.device
)
elif self.sampler_name == 'k_euler_a': elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device) self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
elif self.sampler_name == 'k_euler': elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device) self.sampler = KSampler(
self.model, 'euler', device=self.device
)
elif self.sampler_name == 'k_heun': elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device) self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms': elif self.sampler_name == 'k_lms':
@@ -446,32 +524,38 @@ The vast majority of these arguments default to reasonable values.
return self.model return self.model
def _load_model_from_config(self, config, ckpt): def _load_model_from_config(self, config, ckpt):
print(f"Loading model from {ckpt}") print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location='cpu')
# if "global_step" in pl_sd: # if "global_step" in pl_sd:
# print(f"Global Step: {pl_sd['global_step']}") # print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"] sd = pl_sd['state_dict']
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
model.to(self.device) model.to(self.device)
model.eval() model.eval()
if self.full_precision: if self.full_precision:
print('Using slower but more accurate full-precision math (--full_precision)') print(
'Using slower but more accurate full-precision math (--full_precision)'
)
else: else:
print('Using half precision math. Call with --full_precision to use slower but more accurate full precision.') print(
'Using half precision math. Call with --full_precision to use slower but more accurate full precision.'
)
model.half() model.half()
return model return model
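Editor's note: the m, u captured from load_state_dict(..., strict=False) are the missing and unexpected key lists; collecting them instead of letting a strict load raise is what tolerates minor checkpoint/model mismatches. A toy illustration with a plain nn.Linear:

import torch

# strict=False returns (missing_keys, unexpected_keys) instead of raising.
model = torch.nn.Linear(4, 2)
state = {'weight': torch.zeros(2, 4), 'extra': torch.zeros(1)}   # no 'bias'
m, u = model.load_state_dict(state, strict=False)
print(m)   # ['bias']
print(u)   # ['extra']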
def _load_img(self,path): def _load_img(self, path):
image = Image.open(path).convert("RGB") image = Image.open(path).convert('RGB')
w, h = image.size w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}") print(f'loaded input image of size ({w}, {h}) from {path}')
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = map(
lambda x: x - x % 32, (w, h)
) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS) image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)
return 2.*image - 1. return 2.0 * image - 1.0
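Editor's note: _load_img does two things worth keeping in mind, rounding the image dimensions down to multiples of 32 (required by the latent downsampling) and rescaling pixel values from [0, 1] to [-1, 1]. The same arithmetic without PIL, as a sanity check:

import numpy as np

# Dimension rounding: nearest multiple of 32, rounded down.
w, h = 1023, 771
w, h = map(lambda x: x - x % 32, (w, h))
print(w, h)   # 992 768

# Value rescaling: [0, 1] -> [-1, 1].
pixels = np.array([0, 128, 255], dtype=np.float32) / 255.0
print(2.0 * pixels - 1.0)   # roughly [-1., 0.004, 1.]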
def _split_weighted_subprompts(text): def _split_weighted_subprompts(text):
""" """
@@ -484,34 +568,36 @@ The vast majority of these arguments default to reasonable values.
prompts = [] prompts = []
weights = [] weights = []
while remaining > 0: while remaining > 0:
if ":" in text: if ':' in text:
idx = text.index(":") # first occurrence from start idx = text.index(':') # first occurrence from start
# grab up to index as sub-prompt # grab up to index as sub-prompt
prompt = text[:idx] prompt = text[:idx]
remaining -= idx remaining -= idx
# remove from main text # remove from main text
text = text[idx+1:] text = text[idx + 1 :]
# find value for weight # find value for weight
if " " in text: if ' ' in text:
idx = text.index(" ") # first occurrence idx = text.index(' ')   # first occurrence
else: # no space, read to end else: # no space, read to end
idx = len(text) idx = len(text)
if idx != 0: if idx != 0:
try: try:
weight = float(text[:idx]) weight = float(text[:idx])
except: # couldn't treat as float except: # couldn't treat as float
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") print(
f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
)
weight = 1.0 weight = 1.0
else: # no value found else: # no value found
weight = 1.0 weight = 1.0
# remove from main text # remove from main text
remaining -= idx remaining -= idx
text = text[idx+1:] text = text[idx + 1 :]
# append the sub-prompt and its weight # append the sub-prompt and its weight
prompts.append(prompt) prompts.append(prompt)
weights.append(weight) weights.append(weight)
else: # no : found else: # no : found
if len(text) > 0: # there is still text though if len(text) > 0: # there is still text though
# take remainder as weight 1 # take remainder as weight 1
prompts.append(text) prompts.append(text)
weights.append(1.0) weights.append(1.0)
@@ -519,13 +605,20 @@ The vast majority of these arguments default to reasonable values.
return prompts, weights return prompts, weights
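Editor's note: the syntax being parsed is whitespace-separated text:weight pairs, so an input like 'forest:2 river:1' should come back as two sub-prompts weighted 2 and 1, which _get_uc_and_c then normalizes to 2/3 and 1/3 unless --skip_normalize is passed. A hedged check; the exact edge-case handling depends on the unchanged tail of the function that this hunk truncates.

from ldm.simplet2i import T2I

prompts, weights = T2I._split_weighted_subprompts('forest:2 river:1')
print(prompts)   # expected: ['forest', 'river']
print(weights)   # expected: [2.0, 1.0]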
def _run_gfpgan(self, image, strength): def _run_gfpgan(self, image, strength):
if (self.gfpgan is None): if self.gfpgan is None:
print(f"GFPGAN not initialized, it must be loaded via the --gfpgan argument") print(
f'GFPGAN not initialized, it must be loaded via the --gfpgan argument'
)
return image return image
image = image.convert("RGB")
cropped_faces, restored_faces, restored_img = self.gfpgan.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) image = image.convert('RGB')
cropped_faces, restored_faces, restored_img = self.gfpgan.enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
if strength < 1.0: if strength < 1.0:

View File

@@ -13,22 +13,25 @@ from queue import Queue
from inspect import isfunction from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10): def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height) # wh a tuple of (width, height)
# xc a list of captions to plot # xc a list of captions to plot
b = len(xc) b = len(xc)
txts = list() txts = list()
for bi in range(b): for bi in range(b):
txt = Image.new("RGB", wh, color="white") txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt) draw = ImageDraw.Draw(txt)
font = ImageFont.load_default() font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256)) nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) lines = '\n'.join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try: try:
draw.text((0, 0), lines, fill="black", font=font) draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError: except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.") print('Cant encode string for logging. Skipping.')
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt) txts.append(txt)
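Editor's note: nc works out to roughly 40 characters per 256 pixels of width, and the join/slice expression wraps each caption to that width. The idiom on its own, outside of PIL:

# The caption-wrapping idiom from log_txt_as_img, applied to a bare string.
caption = 'a highly detailed matte painting of a castle on a cliff at sunset'
width = 256
nc = int(40 * (width / 256))   # 40 characters per line at width 256
print('\n'.join(caption[start:start + nc] for start in range(0, len(caption), nc)))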
@@ -70,22 +73,26 @@ def mean_flat(tensor):
def count_params(model, verbose=False): def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters()) total_params = sum(p.numel() for p in model.parameters())
if verbose: if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
)
return total_params return total_params
def instantiate_from_config(config, **kwargs): def instantiate_from_config(config, **kwargs):
if not "target" in config: if not 'target' in config:
if config == '__is_first_stage__': if config == '__is_first_stage__':
return None return None
elif config == "__is_unconditional__": elif config == '__is_unconditional__':
return None return None
raise KeyError("Expected key `target` to instantiate.") raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) return get_obj_from_str(config['target'])(
**config.get('params', dict()), **kwargs
)
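Editor's note: instantiate_from_config is the glue used by all the YAML configs. 'target' is a dotted import path, 'params' are constructor keyword arguments, and get_obj_from_str (just below) resolves the path. A minimal sketch using a standard-library class, assuming these helpers live in ldm.util as in the upstream repository:

from ldm.util import instantiate_from_config

config = {
    'target': 'datetime.timedelta',      # any importable dotted path works
    'params': {'days': 2, 'hours': 3},   # forwarded as keyword arguments
}
obj = instantiate_from_config(config)
print(obj)   # 2 days, 3:00:00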
def get_obj_from_str(string, reload=False): def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1) module, cls = string.rsplit('.', 1)
if reload: if reload:
module_imp = importlib.import_module(module) module_imp = importlib.import_module(module)
importlib.reload(module_imp) importlib.reload(module_imp)
@@ -101,31 +108,36 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else: else:
res = func(data) res = func(data)
Q.put([idx, res]) Q.put([idx, res])
Q.put("Done") Q.put('Done')
def parallel_data_prefetch( def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False func: callable,
data,
n_proc,
target_data_type='ndarray',
cpu_intensive=True,
use_worker_id=False,
): ):
# if target_data_type not in ["ndarray", "list"]: # if target_data_type not in ["ndarray", "list"]:
# raise ValueError( # raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# ) # )
if isinstance(data, np.ndarray) and target_data_type == "list": if isinstance(data, np.ndarray) and target_data_type == 'list':
raise ValueError("list expected but function got ndarray.") raise ValueError('list expected but function got ndarray.')
elif isinstance(data, abc.Iterable): elif isinstance(data, abc.Iterable):
if isinstance(data, dict): if isinstance(data, dict):
print( print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
) )
data = list(data.values()) data = list(data.values())
if target_data_type == "ndarray": if target_data_type == 'ndarray':
data = np.asarray(data) data = np.asarray(data)
else: else:
data = list(data) data = list(data)
else: else:
raise TypeError( raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
) )
if cpu_intensive: if cpu_intensive:
@@ -135,7 +147,7 @@ def parallel_data_prefetch(
Q = Queue(1000) Q = Queue(1000)
proc = Thread proc = Thread
# spawn processes # spawn processes
if target_data_type == "ndarray": if target_data_type == 'ndarray':
arguments = [ arguments = [
[func, Q, part, i, use_worker_id] [func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc)) for i, part in enumerate(np.array_split(data, n_proc))
@@ -149,7 +161,7 @@ def parallel_data_prefetch(
arguments = [ arguments = [
[func, Q, part, i, use_worker_id] [func, Q, part, i, use_worker_id]
for i, part in enumerate( for i, part in enumerate(
[data[i: i + step] for i in range(0, len(data), step)] [data[i : i + step] for i in range(0, len(data), step)]
) )
] ]
processes = [] processes = []
@@ -158,7 +170,7 @@ def parallel_data_prefetch(
processes += [p] processes += [p]
# start processes # start processes
print(f"Start prefetching...") print(f'Start prefetching...')
import time import time
start = time.time() start = time.time()
@@ -171,13 +183,13 @@ def parallel_data_prefetch(
while k < n_proc: while k < n_proc:
# get result # get result
res = Q.get() res = Q.get()
if res == "Done": if res == 'Done':
k += 1 k += 1
else: else:
gather_res[res[0]] = res[1] gather_res[res[0]] = res[1]
except Exception as e: except Exception as e:
print("Exception: ", e) print('Exception: ', e)
for p in processes: for p in processes:
p.terminate() p.terminate()
@@ -185,7 +197,7 @@ def parallel_data_prefetch(
finally: finally:
for p in processes: for p in processes:
p.join() p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]") print(f'Prefetching complete. [{time.time() - start} sec.]')
if target_data_type == 'ndarray': if target_data_type == 'ndarray':
if not isinstance(gather_res[0], np.ndarray): if not isinstance(gather_res[0], np.ndarray):
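Editor's note: a hedged usage sketch for parallel_data_prefetch. Each worker receives one chunk of data, applies func to the whole chunk, and the results are gathered back in worker order. It assumes the helper is importable from ldm.util and that the call runs under a __main__ guard, since the workers are real processes when cpu_intensive is true.

import numpy as np
from ldm.util import parallel_data_prefetch

def square_chunk(chunk):
    # func receives a whole chunk (an ndarray slice), not single elements
    return chunk ** 2

if __name__ == '__main__':
    data = np.arange(16)
    result = parallel_data_prefetch(
        square_chunk, data, n_proc=4, target_data_type='ndarray'
    )
    print(result)   # expected: the 16 squared values, in order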

main.py (685 changes): file diff suppressed because it is too large

View File

@@ -8,62 +8,66 @@ import sys
import copy import copy
import warnings import warnings
import ldm.dream.readline import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter,PromptFormatter from ldm.dream.pngwriter import PngWriter, PromptFormatter
debugging = False debugging = False
def main(): def main():
''' Initialize command-line parsers and the diffusion model ''' """Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser() arg_parser = create_argv_parser()
opt = arg_parser.parse_args() opt = arg_parser.parse_args()
if opt.laion400m: if opt.laion400m:
# defaults suitable to the older latent diffusion weights # defaults suitable to the older latent diffusion weights
width = 256 width = 256
height = 256 height = 256
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml'
weights = "models/ldm/text2img-large/model.ckpt" weights = 'models/ldm/text2img-large/model.ckpt'
else: else:
# some defaults suitable for stable diffusion weights # some defaults suitable for stable diffusion weights
width = 512 width = 512
height = 512 height = 512
config = "configs/stable-diffusion/v1-inference.yaml" config = 'configs/stable-diffusion/v1-inference.yaml'
weights = "models/ldm/stable-diffusion-v1/model.ckpt" weights = 'models/ldm/stable-diffusion-v1/model.ckpt'
print("* Initializing, be patient...\n") print('* Initializing, be patient...\n')
sys.path.append('.') sys.path.append('.')
from pytorch_lightning import logging from pytorch_lightning import logging
from ldm.simplet2i import T2I from ldm.simplet2i import T2I
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported # when the frozen CLIP tokenizer is imported
import transformers import transformers
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# creating a simple text2image object with a handful of # creating a simple text2image object with a handful of
# defaults passed on the command line. # defaults passed on the command line.
# additional parameters will be added (or overridden) during # additional parameters will be added (or overridden) during
# the user input loop # the user input loop
t2i = T2I(width=width, t2i = T2I(
height=height, width=width,
sampler_name=opt.sampler_name, height=height,
weights=weights, sampler_name=opt.sampler_name,
full_precision=opt.full_precision, weights=weights,
config=config, full_precision=opt.full_precision,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt config=config,
embedding_path=opt.embedding_path, latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
device=opt.device embedding_path=opt.embedding_path,
device=opt.device,
) )
# make sure the output directory exists # make sure the output directory exists
if not os.path.exists(opt.outdir): if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir) os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed # gets rid of annoying messages about random seed
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
infile = None infile = None
try: try:
if opt.infile is not None: if opt.infile is not None:
infile = open(opt.infile,'r') infile = open(opt.infile, 'r')
except FileNotFoundError as e: except FileNotFoundError as e:
print(e) print(e)
exit(-1) exit(-1)
@@ -73,135 +77,156 @@ def main():
# load GFPGAN if requested # load GFPGAN if requested
if opt.use_gfpgan: if opt.use_gfpgan:
print("\n* --gfpgan was specified, loading gfpgan...") print('\n* --gfpgan was specified, loading gfpgan...')
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
try: try:
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path) model_path = os.path.join(
opt.gfpgan_dir, opt.gfpgan_model_path
)
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise Exception("GFPGAN model not found at path "+model_path) raise Exception(
'GFPGAN model not found at path ' + model_path
)
sys.path.append(os.path.abspath(opt.gfpgan_dir)) sys.path.append(os.path.abspath(opt.gfpgan_dir))
from gfpgan import GFPGANer from gfpgan import GFPGANer
bg_upsampler = load_gfpgan_bg_upsampler(opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile) bg_upsampler = load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile
)
t2i.gfpgan = GFPGANer(model_path=model_path, upscale=opt.gfpgan_upscale, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler) t2i.gfpgan = GFPGANer(
model_path=model_path,
upscale=opt.gfpgan_upscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler,
)
except Exception: except Exception:
import traceback import traceback
print("Error loading GFPGAN:", file=sys.stderr)
print('Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print("\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...") print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)..."
)
log_path = os.path.join(opt.outdir,'dream_log.txt') log_path = os.path.join(opt.outdir, 'dream_log.txt')
with open(log_path,'a') as log: with open(log_path, 'a') as log:
cmd_parser = create_cmd_parser() cmd_parser = create_cmd_parser()
main_loop(t2i,opt.outdir,cmd_parser,log,infile) main_loop(t2i, opt.outdir, cmd_parser, log, infile)
log.close() log.close()
if infile: if infile:
infile.close() infile.close()
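Editor's note: main() drives T2I through an interactive loop, but the same object can be used library-style. A hedged sketch: the constructor arguments are the stable-diffusion defaults chosen above, the keyword names given to prompt2image mirror the opt fields forwarded via **vars(opt), and each result is assumed to be an (image, seed) pair, as the grid-building code below relies on.

from ldm.simplet2i import T2I

t2i = T2I(
    width=512,
    height=512,
    sampler_name='k_lms',
    weights='models/ldm/stable-diffusion-v1/model.ckpt',
    config='configs/stable-diffusion/v1-inference.yaml',
    full_precision=False,
)
results = t2i.prompt2image(
    prompt='a fantastic alien landscape', steps=50, cfg_scale=7.5
)
for image, seed in results:
    image.save(f'alien-{seed}.png')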
def main_loop(t2i,outdir,parser,log,infile): def main_loop(t2i, outdir, parser, log, infile):
''' prompt/read/execute loop ''' """prompt/read/execute loop"""
done = False done = False
last_seeds = [] last_seeds = []
while not done: while not done:
try: try:
command = infile.readline() if infile else input("dream> ") command = infile.readline() if infile else input('dream> ')
except EOFError: except EOFError:
done = True done = True
break break
if infile and len(command)==0: if infile and len(command) == 0:
done = True done = True
break break
if command.startswith(('#','//')): if command.startswith(('#', '//')):
continue continue
# before splitting, escape single quotes so as not to mess # before splitting, escape single quotes so as not to mess
# up the parser # up the parser
command = command.replace("'","\\'") command = command.replace("'", "\\'")
try: try:
elements = shlex.split(command) elements = shlex.split(command)
except ValueError as e: except ValueError as e:
print(str(e)) print(str(e))
continue continue
if len(elements)==0: if len(elements) == 0:
continue continue
if elements[0]=='q': if elements[0] == 'q':
done = True done = True
break break
if elements[0]=='cd' and len(elements)>1: if elements[0] == 'cd' and len(elements) > 1:
if os.path.exists(elements[1]): if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}") print(f'setting image output directory to {elements[1]}')
outdir=elements[1] outdir = elements[1]
else: else:
print(f"directory {elements[1]} does not exist") print(f'directory {elements[1]} does not exist')
continue continue
if elements[0]=='pwd': if elements[0] == 'pwd':
print(f"current output directory is {outdir}") print(f'current output directory is {outdir}')
continue continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command if elements[0].startswith(
'!dream'
): # in case a stored prompt still contains the !dream command
elements.pop(0) elements.pop(0)
# rearrange the arguments to mimic how it works in the Dream bot. # rearrange the arguments to mimic how it works in the Dream bot.
switches = [''] switches = ['']
switches_started = False switches_started = False
for el in elements: for el in elements:
if el[0]=='-' and not switches_started: if el[0] == '-' and not switches_started:
switches_started = True switches_started = True
if switches_started: if switches_started:
switches.append(el) switches.append(el)
else: else:
switches[0] += el switches[0] += el
switches[0] += ' ' switches[0] += ' '
switches[0] = switches[0][:len(switches[0])-1] switches[0] = switches[0][: len(switches[0]) - 1]
try: try:
opt = parser.parse_args(switches) opt = parser.parse_args(switches)
except SystemExit: except SystemExit:
parser.print_help() parser.print_help()
continue continue
if len(opt.prompt)==0: if len(opt.prompt) == 0:
print("Try again with a prompt!") print('Try again with a prompt!')
continue continue
if opt.seed is not None and opt.seed<0: # retrieve previous value! if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try: try:
opt.seed = last_seeds[opt.seed] opt.seed = last_seeds[opt.seed]
print(f"reusing previous seed {opt.seed}") print(f'reusing previous seed {opt.seed}')
except IndexError: except IndexError:
print(f"No previous seed at position {opt.seed} found") print(f'No previous seed at position {opt.seed} found')
opt.seed = None opt.seed = None
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt() normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
individual_images = not opt.grid individual_images = not opt.grid
try: try:
file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size) file_writer = PngWriter(outdir, normalized_prompt, opt.batch_size)
callback = file_writer.write_image if individual_images else None callback = file_writer.write_image if individual_images else None
image_list = t2i.prompt2image(image_callback=callback,**vars(opt)) image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
results = file_writer.files_written if individual_images else image_list results = (
file_writer.files_written if individual_images else image_list
)
if opt.grid and len(results) > 0: if opt.grid and len(results) > 0:
grid_img = file_writer.make_grid([r[0] for r in results]) grid_img = file_writer.make_grid([r[0] for r in results])
filename = file_writer.unique_filename(results[0][1]) filename = file_writer.unique_filename(results[0][1])
seeds = [a[1] for a in results] seeds = [a[1] for a in results]
results = [[filename,seeds]] results = [[filename, seeds]]
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}' metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
file_writer.save_image_and_prompt_to_png(grid_img,metadata_prompt,filename) file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
)
last_seeds = [r[1] for r in results] last_seeds = [r[1] for r in results]
@@ -213,10 +238,11 @@ def main_loop(t2i,outdir,parser,log,infile):
print(e) print(e)
continue continue
print("Outputs:") print('Outputs:')
write_log_message(t2i,normalized_prompt,results,log) write_log_message(t2i, normalized_prompt, results, log)
print('goodbye!')
print("goodbye!")
def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400): def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
import torch import torch
@@ -224,13 +250,24 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
if bg_upsampler == 'realesrgan': if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU if not torch.cuda.is_available(): # CPU
import warnings import warnings
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.') warnings.warn(
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.'
)
bg_upsampler = None bg_upsampler = None
else: else:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer( bg_upsampler = RealESRGANer(
scale=2, scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
@@ -238,12 +275,14 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
tile=bg_tile, tile=bg_tile,
tile_pad=10, tile_pad=10,
pre_pad=0, pre_pad=0,
half=True) # need to set False in CPU mode half=True,
) # need to set False in CPU mode
else: else:
bg_upsampler = None bg_upsampler = None
return bg_upsampler return bg_upsampler
# variant generation is going to be superseded by a generalized # variant generation is going to be superseded by a generalized
# "prompt-morph" functionality # "prompt-morph" functionality
# def generate_variants(t2i,outdir,opt,previous_gens): # def generate_variants(t2i,outdir,opt,previous_gens):
@@ -268,110 +307,209 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
# continue # continue
# print(f'{opt.variants} variants generated') # print(f'{opt.variants} variants generated')
# return variants # return variants
def write_log_message(t2i,prompt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata''' def write_log_message(t2i, prompt, results, logfile):
last_seed = None """logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata"""
img_num = 1 last_seed = None
seenit = {} img_num = 1
seenit = {}
for r in results: for r in results:
seed = r[1] seed = r[1]
log_message = (f'{r[0]}: {prompt} -S{seed}') log_message = f'{r[0]}: {prompt} -S{seed}'
print(log_message) print(log_message)
logfile.write(log_message+"\n") logfile.write(log_message + '\n')
logfile.flush() logfile.flush()
def create_argv_parser(): def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args") parser = argparse.ArgumentParser(
parser.add_argument("--laion400m", description="Parse script's command line args"
"--latent_diffusion", )
"-l", parser.add_argument(
dest='laion400m', '--laion400m',
action='store_true', '--latent_diffusion',
help="fallback to the latent diffusion (laion400m) weights and config") '-l',
parser.add_argument("--from_file", dest='laion400m',
dest='infile', action='store_true',
type=str, help='fallback to the latent diffusion (laion400m) weights and config',
help="if specified, load prompts from this file") )
parser.add_argument('-n','--iterations', parser.add_argument(
type=int, '--from_file',
default=1, dest='infile',
help="number of images to generate") type=str,
parser.add_argument('-F','--full_precision', help='if specified, load prompts from this file',
dest='full_precision', )
action='store_true', parser.add_argument(
help="use slower full precision math for calculations") '-n',
parser.add_argument('--sampler','-m', '--iterations',
dest="sampler_name", type=int,
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'], default=1,
default='k_lms', help='number of images to generate',
help="which sampler to use (k_lms) - can only be set on command line") )
parser.add_argument('--outdir', parser.add_argument(
'-o', '-F',
type=str, '--full_precision',
default="outputs/img-samples", dest='full_precision',
help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)") action='store_true',
parser.add_argument('--embedding_path', help='use slower full precision math for calculations',
type=str, )
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line") parser.add_argument(
parser.add_argument('--device', '--sampler',
'-d', '-m',
type=str, dest='sampler_name',
default="cuda", choices=[
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available") choices=[
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
],
default='k_lms',
help='which sampler to use (k_lms) - can only be set on command line',
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default='outputs/img-samples',
help='directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)',
)
parser.add_argument(
'--embedding_path',
type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
)
parser.add_argument(
'--device',
'-d',
type=str,
default='cuda',
help='device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available',
)
# GFPGAN related args # GFPGAN related args
parser.add_argument('--gfpgan', parser.add_argument(
dest='use_gfpgan', '--gfpgan',
action='store_true', dest='use_gfpgan',
help="load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory") action='store_true',
parser.add_argument("--gfpgan_upscale", help='load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory',
type=int, )
default=2, parser.add_argument(
help="The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified") '--gfpgan_upscale',
parser.add_argument("--gfpgan_bg_upsampler", type=int,
type=str, default=2,
default='realesrgan', help='The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified',
help="Background upsampler. Default: None. Options: realesrgan, none. Only used if --gfpgan is specified") )
parser.add_argument("--gfpgan_bg_tile", parser.add_argument(
type=int, '--gfpgan_bg_upsampler',
default=400, type=str,
help="Tile size for background sampler, 0 for no tile during testing. Default: 400. Only used if --gfpgan is specified") default='realesrgan',
parser.add_argument("--gfpgan_model_path", help='Background upsampler. Default: None. Options: realesrgan, none. Only used if --gfpgan is specified',
type=str, )
default='experiments/pretrained_models/GFPGANv1.3.pth', parser.add_argument(
help="indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified") '--gfpgan_bg_tile',
parser.add_argument("--gfpgan_dir", type=int,
type=str, default=400,
default='../GFPGAN', help='Tile size for background sampler, 0 for no tile during testing. Default: 400. Only used if --gfpgan is specified',
help="indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified") )
parser.add_argument(
'--gfpgan_model_path',
type=str,
default='experiments/pretrained_models/GFPGANv1.3.pth',
help='indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_dir',
type=str,
default='../GFPGAN',
help='indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified',
)
return parser return parser
def create_cmd_parser(): def create_cmd_parser():
parser = argparse.ArgumentParser(description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12') parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
)
parser.add_argument('prompt') parser.add_argument('prompt')
parser.add_argument('-s','--steps',type=int,help="number of steps") parser.add_argument('-s', '--steps', type=int, help='number of steps')
parser.add_argument('-S','--seed',type=int,help="image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc") parser.add_argument(
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform (slower, but will provide seeds for individual images)") '-S',
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (will not provide seeds for individual images!)") '--seed',
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64") type=int,
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64") help='image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale") )
parser.add_argument('-g','--grid',action='store_true',help="generate a grid") parser.add_argument(
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)") '-n',
parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)") '--iterations',
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely") type=int,
parser.add_argument('-G','--gfpgan_strength', default=0.5, type=float, help="The strength at which to apply the GFPGAN model to the result, in order to improve faces.") default=1,
# variants is going to be superseded by a generalized "prompt-morph" function help='number of samplings to perform (slower, but will provide seeds for individual images)',
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants") )
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization") parser.add_argument(
'-b',
'--batch_size',
type=int,
default=1,
help='number of images to produce per sampling (will not provide seeds for individual images!)',
)
parser.add_argument(
'-W', '--width', type=int, help='image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='prompt configuration scale',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0.5,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='skip subprompt weight normalization',
)
return parser return parser
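Editor's note: this parser also backs the interactive dream> prompt. A hedged sketch of parsing one command the way main_loop does above: split with shlex, collect everything before the first switch into the positional prompt, then hand the rest to argparse. It assumes create_cmd_parser is in scope, for example when run inside this script.

import shlex

parser = create_cmd_parser()
command = 'a fantastic alien landscape -W1024 -H960 -s100 -n12'
elements = shlex.split(command)

# Mimic main_loop: everything before the first switch becomes the prompt.
switches = ['']
switches_started = False
for el in elements:
    if el[0] == '-' and not switches_started:
        switches_started = True
    if switches_started:
        switches.append(el)
    else:
        switches[0] += el + ' '
switches[0] = switches[0].rstrip()

opt = parser.parse_args(switches)
print(opt.prompt, opt.width, opt.height, opt.steps, opt.iterations)
# expected: a fantastic alien landscape 1024 960 100 12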
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@@ -11,26 +11,28 @@ import warnings
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# this will preload the Bert tokenizer files # this will preload the Bert tokenizer files
print("preloading bert tokenizer...") print('preloading bert tokenizer...')
from transformers import BertTokenizerFast from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
print("...success") tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
# this will download requirements for Kornia # this will download requirements for Kornia
print("preloading Kornia requirements (ignore the deprecation warnings)...") print('preloading Kornia requirements (ignore the deprecation warnings)...')
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia import kornia
print("...success") print('...success')
version='openai/clip-vit-large-patch14' version = 'openai/clip-vit-large-patch14'
print('preloading CLIP model (Ignore the deprecation warnings)...') print('preloading CLIP model (Ignore the deprecation warnings)...')
sys.stdout.flush() sys.stdout.flush()
import clip import clip
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
tokenizer =CLIPTokenizer.from_pretrained(version)
transformer=CLIPTextModel.from_pretrained(version) tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('\n\n...success') print('\n\n...success')
# In the event that the user has installed GFPGAN and also elected to use # In the event that the user has installed GFPGAN and also elected to use
@@ -38,23 +40,33 @@ print('\n\n...success')
gfpgan = False gfpgan = False
try: try:
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
gfpgan = True gfpgan = True
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
if gfpgan: if gfpgan:
print("Loading models from RealESRGAN and facexlib") print('Loading models from RealESRGAN and facexlib')
try: try:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', RealESRGANer(
model=RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)) scale=2,
FaceRestoreHelper(1,det_model='retinaface_resnet50') model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
print("...success") model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
),
)
FaceRestoreHelper(1, det_model='retinaface_resnet50')
print('...success')
except Exception: except Exception:
import traceback import traceback
print("Error loading GFPGAN:")
print('Error loading GFPGAN:')
print(traceback.format_exc()) print(traceback.format_exc())