prettified all the code using "blue" at the urging of @tildebyte

Lincoln Stein 2022-08-26 03:15:42 -04:00
parent dd670200bb
commit 4f02b72c9c
35 changed files with 6252 additions and 3119 deletions
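For context: blue is a Black-derived Python formatter whose defaults prefer single-quoted strings and a shorter (79-character) line length, which is what produces the quote changes and line wrapping in the hunks below. It is normally run over the source tree much like Black (for example, "blue ."), though the exact invocation is not recorded in this commit. Below is a minimal sketch of the kind of rewrite it applies; the describe_dataset() function is purely illustrative and does not appear anywhere in the repository.

# Hypothetical before/after, not taken from this commit.
# Before formatting:
#   def describe_dataset(name="ILSVRC2012_train", size=256, random_crop=True):
#       return "{} dataset: size={}, random_crop={}".format(name, size, random_crop)
# After formatting with blue (single quotes, long call wrapped to the line limit):
def describe_dataset(name='ILSVRC2012_train', size=256, random_crop=True):
    return '{} dataset: size={}, random_crop={}'.format(
        name, size, random_crop
    )

print(describe_dataset())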


@ -1,11 +1,17 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
from torch.utils.data import (
Dataset,
ConcatDataset,
ChainDataset,
IterableDataset,
)
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
"""
Define an interface to make the IterableDatasets for text2img data chainable
'''
"""
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
@ -13,7 +19,9 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
print(
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
def __len__(self):
return self.num_records


@ -11,24 +11,34 @@ from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import (
str_to_indices,
give_synsets_from_indices,
download,
retrieve,
)
from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from ldm.modules.image_degradation import (
degradation_fn_bsr,
degradation_fn_bsr_light,
)
def synset2idx(path_to_yaml="data/index_synset.yaml"):
def synset2idx(path_to_yaml='data/index_synset.yaml'):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v,k) for k,v in di2s.items())
return dict((v, k) for k, v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not type(self.config)==dict:
if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.keep_orig_class_label = self.config.get(
'keep_orig_class_label', False
)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare()
self._prepare_synset_to_human()
@ -46,17 +56,23 @@ class ImageNetBase(Dataset):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
ignore = set(
[
'n06596364_9591.JPEG',
]
)
relpaths = [
rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore
]
if 'sub_indices' in self.config:
indices = str_to_indices(self.config['sub_indices'])
synsets = give_synsets_from_indices(
indices, path_to_yaml=self.idx2syn
) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
syn = rpath.split("/")[0]
syn = rpath.split('/')[0]
if syn in synsets:
files.append(rpath)
return files
@ -65,64 +81,75 @@ class ImageNetBase(Dataset):
def _prepare_synset_to_human(self):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
self.human_dict = os.path.join(self.root, 'synset_human.txt')
if (
not os.path.exists(self.human_dict)
or not os.path.getsize(self.human_dict) == SIZE
):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
if not os.path.exists(self.idx2syn):
download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self):
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
if (not os.path.exists(self.human2integer)):
URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
self.human2integer = os.path.join(
self.root, 'imagenet1000_clsidx_to_labels.txt'
)
if not os.path.exists(self.human2integer):
download(URL, self.human2integer)
with open(self.human2integer, "r") as f:
with open(self.human2integer, 'r') as f:
lines = f.read().splitlines()
assert len(lines) == 1000
self.human2integer_dict = dict()
for line in lines:
value, key = line.split(":")
value, key = line.split(':')
self.human2integer_dict[key] = int(value)
def _load(self):
with open(self.txt_filelist, "r") as f:
with open(self.txt_filelist, 'r') as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
print(
'Removed {} files from filelist during filtering.'.format(
l1 - len(self.relpaths)
)
)
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.synsets = [p.split('/')[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
class_dict = dict(
(synset, i) for i, synset in enumerate(unique_synsets)
)
if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets]
else:
self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, "r") as f:
with open(self.human_dict, 'r') as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets]
labels = {
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
'relpath': np.array(self.relpaths),
'synsets': np.array(self.synsets),
'class_label': np.array(self.class_labels),
'human_label': np.array(self.human_labels),
}
if self.process_images:
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(self.abspaths,
self.size = retrieve(self.config, 'size', default=256)
self.data = ImagePaths(
self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
@ -132,11 +159,11 @@ class ImageNetBase(Dataset):
class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
NAME = 'ILSVRC2012_train'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
FILES = [
"ILSVRC2012_img_train.tar",
'ILSVRC2012_img_train.tar',
]
SIZES = [
147897477120,
@ -151,57 +178,64 @@ class ImageNetTrain(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
self.random_crop = retrieve(
self.config, 'ImageNetTrain/random_crop', default=True
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir)
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
print('Extracting sub-tars.')
subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
subdir = subpath[: -len('.tar')]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
with tarfile.open(subpath, 'r:') as tar:
tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
NAME = 'ILSVRC2012_validation'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
FILES = [
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
'ILSVRC2012_img_val.tar',
'validation_synset.txt',
]
SIZES = [
6744924160,
@ -217,39 +251,49 @@ class ImageNetValidation(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
self.random_crop = retrieve(
self.config, 'ImageNetValidation/random_crop', default=False
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
if (
not os.path.exists(vspath)
or not os.path.getsize(vspath) == self.SIZES[1]
):
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
with open(vspath, 'r') as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders")
print('Reorganizing into synset folders')
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
@ -258,21 +302,26 @@ class ImageNetValidation(ImageNetBase):
dst = os.path.join(datadir, v)
shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
random_crop=True):
def __init__(
self,
size=None,
degradation=None,
downscale_f=4,
min_crop_f=0.5,
max_crop_f=1.0,
random_crop=True,
):
"""
Imagenet Superresolution Dataloader
Performs following ops in order:
@ -296,67 +345,86 @@ class ImageNetSR(Dataset):
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.)
assert max_crop_f <= 1.0
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
self.image_rescaler = albumentations.SmallestMaxSize(
max_size=size, interpolation=cv2.INTER_AREA
)
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
self.pil_interpolation = (
False # gets reset later if incase interp_op is from pillow
)
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
if degradation == 'bsrgan':
self.degradation_process = partial(
degradation_fn_bsr, sf=downscale_f
)
elif degradation == "bsrgan_light":
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
elif degradation == 'bsrgan_light':
self.degradation_process = partial(
degradation_fn_bsr_light, sf=downscale_f
)
else:
interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
'cv_nearest': cv2.INTER_NEAREST,
'cv_bilinear': cv2.INTER_LINEAR,
'cv_bicubic': cv2.INTER_CUBIC,
'cv_area': cv2.INTER_AREA,
'cv_lanczos': cv2.INTER_LANCZOS4,
'pil_nearest': PIL.Image.NEAREST,
'pil_bilinear': PIL.Image.BILINEAR,
'pil_bicubic': PIL.Image.BICUBIC,
'pil_box': PIL.Image.BOX,
'pil_hamming': PIL.Image.HAMMING,
'pil_lanczos': PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
self.pil_interpolation = degradation.startswith('pil_')
if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
self.degradation_process = partial(
TF.resize,
size=self.LR_size,
interpolation=interpolation_fn,
)
else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
interpolation=interpolation_fn)
self.degradation_process = albumentations.SmallestMaxSize(
max_size=self.LR_size, interpolation=interpolation_fn
)
def __len__(self):
return len(self.base)
def __getitem__(self, i):
example = self.base[i]
image = Image.open(example["file_path_"])
image = Image.open(example['file_path_'])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
crop_side_len = min_side_len * np.random.uniform(
self.min_crop_f, self.max_crop_f, size=None
)
crop_side_len = int(crop_side_len)
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.CenterCrop(
height=crop_side_len, width=crop_side_len
)
else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.RandomCrop(
height=crop_side_len, width=crop_side_len
)
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]
image = self.cropper(image=image)['image']
image = self.image_rescaler(image=image)['image']
if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image)
@ -364,10 +432,10 @@ class ImageNetSR(Dataset):
LR_image = np.array(LR_image).astype(np.uint8)
else:
LR_image = self.degradation_process(image=image)["image"]
LR_image = self.degradation_process(image=image)['image']
example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)
return example
@ -377,9 +445,11 @@ class ImageNetSRTrain(ImageNetSR):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
with open('data/imagenet_train_hr_indices.p', 'rb') as f:
indices = pickle.load(f)
dset = ImageNetTrain(process_images=False,)
dset = ImageNetTrain(
process_images=False,
)
return Subset(dset, indices)
@ -388,7 +458,9 @@ class ImageNetSRValidation(ImageNetSR):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
with open('data/imagenet_val_hr_indices.p', 'rb') as f:
indices = pickle.load(f)
dset = ImageNetValidation(process_images=False,)
dset = ImageNetValidation(
process_images=False,
)
return Subset(dset, indices)


@ -7,29 +7,32 @@ from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
def __init__(
self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
interpolation='bicubic',
flip_p=0.5,
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
with open(self.data_paths, 'r') as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
'relative_file_path_': [l for l in self.image_paths],
'file_path_': [
os.path.join(self.data_root, l) for l in self.image_paths
],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -38,55 +41,86 @@ class LSUNBase(Dataset):
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = Image.open(example['file_path_'])
if not image.mode == 'RGB':
image = image.convert('RGB')
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
super().__init__(
txt_file='data/lsun/church_outdoor_train.txt',
data_root='data/lsun/churches',
**kwargs
)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/church_outdoor_val.txt',
data_root='data/lsun/churches',
flip_p=flip_p,
**kwargs
)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
super().__init__(
txt_file='data/lsun/bedrooms_train.txt',
data_root='data/lsun/bedrooms',
**kwargs
)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
super().__init__(
txt_file='data/lsun/bedrooms_val.txt',
data_root='data/lsun/bedrooms',
flip_p=flip_p,
**kwargs
)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
super().__init__(
txt_file='data/lsun/cat_train.txt',
data_root='data/lsun/cats',
**kwargs
)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/cat_val.txt',
data_root='data/lsun/cats',
flip_p=flip_p,
**kwargs
)


@ -72,18 +72,41 @@ imagenet_dual_templates_small = [
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
interpolation='bicubic',
flip_p=0.5,
set="train",
placeholder_token="*",
set='train',
placeholder_token='*',
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
@ -92,7 +115,10 @@ class PersonalizedBase(Dataset):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
@ -107,16 +133,19 @@ class PersonalizedBase(Dataset):
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
if set == 'train':
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -127,34 +156,47 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
placeholder_string = (
f'{self.coarse_class_text} {placeholder_string}'
)
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images])
text = random.choice(imagenet_dual_templates_small).format(
placeholder_string, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(placeholder_string)
text = random.choice(imagenet_templates_small).format(
placeholder_string
)
example["caption"] = text
example['caption'] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example


@ -50,25 +50,51 @@ imagenet_dual_templates_small = [
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
interpolation='bicubic',
flip_p=0.5,
set="train",
placeholder_token="*",
set='train',
placeholder_token='*',
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
@ -80,16 +106,19 @@ class PersonalizedBase(Dataset):
self.center_crop = center_crop
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
if set == 'train':
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -100,30 +129,41 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
text = random.choice(imagenet_dual_templates_small).format(
self.placeholder_token, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(self.placeholder_token)
text = random.choice(imagenet_templates_small).format(
self.placeholder_token
)
example["caption"] = text
example['caption'] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example


@ -1,4 +1,4 @@
'''
"""
Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds
appropriate names for them, and writes prompt metadata
@ -7,16 +7,15 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds
prompt for file/directory names.
PromptFormatter -- Utility for converting a Namespace of prompt parameters
back into a formatted prompt string with command-line switches.
'''
"""
import os
import re
from math import sqrt,floor,ceil
from PIL import Image,PngImagePlugin
from math import sqrt, floor, ceil
from PIL import Image, PngImagePlugin
# -------------------image generation utils-----
class PngWriter:
def __init__(self,outdir,prompt=None,batch_size=1):
def __init__(self, outdir, prompt=None, batch_size=1):
self.outdir = outdir
self.batch_size = batch_size
self.prompt = prompt
@ -24,36 +23,41 @@ class PngWriter:
self.files_written = []
os.makedirs(outdir, exist_ok=True)
def write_image(self,image,seed):
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
def write_image(self, image, seed):
self.filepath = self.unique_filename(
seed, self.filepath
) # will increment name in some sensible way
try:
prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image,prompt,self.filepath)
self.save_image_and_prompt_to_png(image, prompt, self.filepath)
except IOError as e:
print(e)
self.files_written.append([self.filepath,seed])
self.files_written.append([self.filepath, seed])
def unique_filename(self,seed,previouspath=None):
def unique_filename(self, seed, previouspath=None):
revision = 1
if previouspath is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir),reverse=True)
dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
basecount = int(filename.split('.',1)[0])
filename = next(
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
'0000000.0.png',
)
basecount = int(filename.split('.', 1)[0])
basecount += 1
if self.batch_size > 1:
filename = f'{basecount:06}.{seed}.01.png'
else:
filename = f'{basecount:06}.{seed}.png'
return os.path.join(self.outdir,filename)
return os.path.join(self.outdir, filename)
else:
basename = os.path.basename(previouspath)
x = re.match('^(\d+)\..*\.png',basename)
x = re.match('^(\d+)\..*\.png', basename)
if not x:
return self.unique_filename(seed,previouspath)
return self.unique_filename(seed, previouspath)
basecount = int(x.groups()[0])
series = 0
@ -61,39 +65,44 @@ class PngWriter:
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if self.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
if self.batch_size > 1 or os.path.exists(
os.path.join(self.outdir, filename)
):
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(os.path.join(self.outdir,filename))
return os.path.join(self.outdir,filename)
finished = not os.path.exists(
os.path.join(self.outdir, filename)
)
return os.path.join(self.outdir, filename)
def save_image_and_prompt_to_png(self,image,prompt,path):
def save_image_and_prompt_to_png(self, image, prompt, path):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
image.save(path,"PNG",pnginfo=info)
info.add_text('Dream', prompt)
image.save(path, 'PNG', pnginfo=info)
def make_grid(self,image_list,rows=None,cols=None):
def make_grid(self, image_list, rows=None, cols=None):
image_cnt = len(image_list)
if None in (rows,cols):
if None in (rows, cols):
rows = floor(sqrt(image_cnt)) # try to make it square
cols = ceil(image_cnt/rows)
cols = ceil(image_cnt / rows)
width = image_list[0].width
height = image_list[0].height
grid_img = Image.new('RGB',(width*cols,height*rows))
for r in range(0,rows):
for c in range (0,cols):
i = r*rows + c
grid_img.paste(image_list[i],(c*width,r*height))
grid_img = Image.new('RGB', (width * cols, height * rows))
for r in range(0, rows):
for c in range(0, cols):
i = r * rows + c
grid_img.paste(image_list[i], (c * width, r * height))
return grid_img
class PromptFormatter():
def __init__(self,t2i,opt):
class PromptFormatter:
def __init__(self, t2i, opt):
self.t2i = t2i
self.opt = opt
def normalize_prompt(self):
'''Normalize the prompt and switches'''
"""Normalize the prompt and switches"""
t2i = self.t2i
opt = self.opt
@ -114,4 +123,3 @@ class PromptFormatter():
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)


@ -1,37 +1,40 @@
'''
"""
Readline helper functions for dream.py (linux and mac only).
'''
"""
import os
import re
import atexit
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except:
readline_available = False
class Completer():
def __init__(self,options):
class Completer:
def __init__(self, options):
self.options = sorted(options)
return
def complete(self,text,state):
def complete(self, text, state):
buffer = readline.get_line_buffer()
if text.startswith(('-I','--init_img')):
return self._path_completions(text,state,('.png'))
if text.startswith(('-I', '--init_img')):
return self._path_completions(text, state, ('.png'))
if buffer.strip().endswith('cd') or text.startswith(('.','/')):
return self._path_completions(text,state,())
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
return self._path_completions(text, state, ())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [s
for s in self.options
if s and s.startswith(text)]
self.matches = [
s for s in self.options if s and s.startswith(text)
]
else:
self.matches = self.options[:]
@ -43,32 +46,34 @@ class Completer():
response = None
return response
def _path_completions(self,text,state,extensions):
def _path_completions(self, text, state, extensions):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I','',1).lstrip()
path = text.replace('-I', '', 1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=','',1).lstrip()
path = text.replace('--init_img=', '', 1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path)==0:
matches.append(text+'./')
if len(path) == 0:
matches.append(text + './')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n)>1:
if n.startswith('.') and len(n) > 1:
continue
full_path = os.path.join(dir,n)
full_path = os.path.join(dir, n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(os.path.join(os.path.dirname(text),n)+'/')
matches.append(
os.path.join(os.path.dirname(text), n) + '/'
)
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text),n))
matches.append(os.path.join(os.path.dirname(text), n))
try:
response = matches[state]
@ -76,19 +81,47 @@ class Completer():
response = None
return response
if readline_available:
readline.set_completer(Completer(['cd','pwd',
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
readline.set_completer_delims(" ")
readline.set_completer(
Completer(
[
'cd',
'pwd',
'--steps',
'-s',
'--seed',
'-S',
'--iterations',
'-n',
'--batch_size',
'-b',
'--width',
'-W',
'--height',
'-H',
'--cfg_scale',
'-C',
'--grid',
'-g',
'--individual',
'-i',
'--init_img',
'-I',
'--strength',
'-f',
'-v',
'--variants',
]
).complete
)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file,histfile)
atexit.register(readline.write_history_file, histfile)


@ -5,32 +5,49 @@ class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_lr}'
)
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
@ -38,15 +55,30 @@ class LambdaWarmUpCosineScheduler2:
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
def __init__(
self,
warm_up_steps,
f_min,
f_max,
f_start,
cycle_lengths,
verbosity_interval=0,
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
@ -60,17 +92,25 @@ class LambdaWarmUpCosineScheduler2:
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}'
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (
self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
f = self.f_min[cycle] + 0.5 * (
self.f_max[cycle] - self.f_min[cycle]
) * (1 + np.cos(t * np.pi))
self.last_f = f
return f
@ -79,20 +119,25 @@ class LambdaWarmUpCosineScheduler2:
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}'
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (
self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f


@ -6,20 +6,23 @@ from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.distributions.distributions import (
DiagonalGaussianDistribution,
)
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
image_key='image',
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
@ -27,7 +30,7 @@ class VQModel(pl.LightningModule):
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
use_ema=False,
):
super().__init__()
self.embed_dim = embed_dim
@ -36,24 +39,34 @@ class VQModel(pl.LightningModule):
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@ -66,28 +79,30 @@ class VQModel(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
print(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
print(f'{context}: Restored training weights')
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
print(f'Missing Keys: {missing}')
print(f'Unexpected Keys: {unexpected}')
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
@ -115,7 +130,7 @@ class VQModel(pl.LightningModule):
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
@ -125,7 +140,11 @@ class VQModel(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
@ -133,9 +152,11 @@ class VQModel(pl.LightningModule):
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = F.interpolate(x, size=new_resize, mode='bicubic')
x = x.detach()
return x
@ -147,81 +168,139 @@ class VQModel(pl.LightningModule):
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
predicted_indices=ind,
)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
log_dict_ema = self._validation_step(
batch, batch_idx, suffix='_ema'
)
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
def _validation_step(self, batch, batch_idx, suffix=''):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
split='val' + suffix,
predicted_indices=ind,
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
split='val' + suffix,
predicted_indices=ind,
)
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
self.log(
f'val{suffix}/rec_loss',
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log(
f'val{suffix}/aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
del log_dict_ae[f'val{suffix}/rec_loss']
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
lr_g = self.lr_g_factor * self.learning_rate
print('lr_d', lr_d)
print('lr_g', lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
opt_ae, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
'frequency': 1,
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
opt_disc, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
'frequency': 1,
},
]
return [opt_ae, opt_disc], scheduler
@ -235,7 +314,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
log['inputs'] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
@ -243,21 +322,24 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
log['inputs'] = x
log['reconstructions'] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
@ -283,13 +365,14 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule):
def __init__(self,
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
image_key='image',
colorize_nlabels=None,
monitor=None,
):
@ -298,28 +381,34 @@ class AutoencoderKL(pl.LightningModule):
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(
2 * ddconfig['z_channels'], 2 * embed_dim, 1
)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
print(f'Restored from {path}')
def encode(self, x):
h = self.encoder(x)
@ -345,7 +434,11 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
return x
def training_step(self, batch, batch_idx, optimizer_idx):
@ -354,44 +447,102 @@ class AutoencoderKL(pl.LightningModule):
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
self.log(
'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
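Aside: `configure_optimizers` hands Lightning two Adam optimizers and no scheduler, and `training_step` then uses `optimizer_idx` to alternate between the autoencoder update (index 0) and the discriminator update (index 1). Stripped of Lightning, the alternation is roughly the loop below; the modules and losses are toy stand-ins, not the actual VAE/GAN objectives used here.

import torch
import torch.nn as nn

autoencoder = nn.Linear(8, 8)      # stand-in for encoder, decoder and quant convs
discriminator = nn.Linear(8, 1)    # stand-in for the patch discriminator inside self.loss
opt_ae = torch.optim.Adam(autoencoder.parameters(), lr=1e-4, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

for batch in [torch.randn(4, 8) for _ in range(3)]:
    # optimizer_idx == 0: update the autoencoder against reconstruction + adversarial terms
    recon = autoencoder(batch)
    ae_loss = (recon - batch).pow(2).mean() - discriminator(recon).mean()
    opt_ae.zero_grad()
    ae_loss.backward()
    opt_ae.step()

    # optimizer_idx == 1: update the discriminator on detached reconstructions
    d_loss = (discriminator(recon.detach()) - discriminator(batch)).mean()
    opt_disc.zero_grad()
    d_loss.backward()
    opt_disc.step()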
def get_last_layer(self):
@ -409,17 +560,19 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
log['inputs'] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x

View File

@ -10,13 +10,13 @@ from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.modules.diffusionmodules.openaimodel import (
EncoderUNetModel,
UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = {
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}
__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
def disabled_train(self, mode=True):
@ -26,8 +26,8 @@ def disabled_train(self, mode=True):
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(self,
def __init__(
self,
diffusion_path,
num_classes,
ckpt_path=None,
@ -35,28 +35,40 @@ class NoisyLatentImageClassifier(pl.LightningModule):
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.e-2,
weight_decay=1.0e-2,
log_steps=10,
monitor='val/loss',
*args,
**kwargs):
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
diffusion_config = natsorted(
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.numd = (
self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
)
self.log_time_interval = (
self.diffusion_model.num_timesteps // log_steps
)
self.log_steps = log_steps
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
self.label_key = (
label_key
if not hasattr(self.diffusion_model, 'cond_stage_key')
else self.diffusion_model.cond_stage_key
)
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
assert (
self.label_key is not None
), 'label_key neither in diffusion model nor in model.params'
if self.label_key not in __models__:
raise NotImplementedError()
@ -68,22 +80,27 @@ class NoisyLatentImageClassifier(pl.LightningModule):
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
sd = torch.load(path, map_location='cpu')
if 'state_dict' in list(sd.keys()):
sd = sd['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f'Missing Keys: {missing}')
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
print(f'Unexpected Keys: {unexpected}')
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
@ -93,17 +110,25 @@ class NoisyLatentImageClassifier(pl.LightningModule):
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
model_config = deepcopy(
self.diffusion_config.params.unet_config.params
)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes
if self.label_key == 'class_label':
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print('#####################################################################')
print(
'#####################################################################'
)
print(f'load from ckpt "{ckpt_path}"')
print('#####################################################################')
print(
'#####################################################################'
)
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
@ -111,11 +136,19 @@ class NoisyLatentImageClassifier(pl.LightningModule):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(
x.shape[0], t + 1
)
)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
return self.diffusion_model.q_sample(
x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
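Aside: `q_sample` is the closed-form forward-diffusion step x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise (the `continuous_sqrt_alpha_cumprod` argument merely swaps the table lookup for a sampled noise level). A self-contained sketch of the discrete case, with a toy schedule and illustrative names:

import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)   # toy noise schedule
x0 = torch.randn(1, 4, 8, 8)                           # toy clean latent
t = torch.tensor([500])
noise = torch.randn_like(x0)

sqrt_ac = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
sqrt_om = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
x_noisy = sqrt_ac * x0 + sqrt_om * noise               # latent corrupted to timestep t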
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@ -141,17 +174,21 @@ class NoisyLatentImageClassifier(pl.LightningModule):
targets = rearrange(targets, 'b h w c -> b c h w')
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
targets = F.interpolate(
targets, size=(h // 2, w // 2), mode='nearest'
)
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
def compute_top_k(self, logits, labels, k, reduction='mean'):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
if reduction == 'mean':
return (
(top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
)
elif reduction == 'none':
return (top_ks == labels[:, None]).float().sum(dim=-1)
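Aside: `compute_top_k` is ordinary top-k accuracy. The same computation on toy logits (values illustrative):

import torch

logits = torch.tensor([[0.10, 0.70, 0.20],
                       [0.80, 0.15, 0.05]])
labels = torch.tensor([1, 2])

_, top_ks = torch.topk(logits, k=2, dim=1)              # indices of the two largest logits per row
acc_at_2 = (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc_at_2)                                         # 0.5: label 1 is in row 0's top-2, label 2 is not in row 1's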
def on_train_epoch_start(self):
@ -162,29 +199,59 @@ class NoisyLatentImageClassifier(pl.LightningModule):
def write_logs(self, loss, logits, targets):
log_prefix = 'train' if self.training else 'val'
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
log[f'{log_prefix}/loss'] = loss.mean()
log[f'{log_prefix}/acc@1'] = self.compute_top_k(
logits, targets, k=1, reduction='mean'
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
log[f'{log_prefix}/acc@5'] = self.compute_top_k(
logits, targets, k=5, reduction='mean'
)
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
self.log_dict(
log,
prog_bar=False,
logger=True,
on_step=self.training,
on_epoch=True,
)
self.log(
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
)
self.log(
'global_step',
self.global_step,
logger=False,
on_epoch=False,
prog_bar=True,
)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
self.log(
'lr_abs',
lr,
on_step=True,
logger=True,
on_epoch=False,
prog_bar=True,
)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
t = torch.randint(
0,
self.diffusion_model.num_timesteps,
(x.shape[0],),
device=self.device,
).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
t = torch.full(
size=(x.shape[0],), fill_value=t, device=self.device
).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
@ -200,8 +267,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
return loss
def reset_noise_accs(self):
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
self.noisy_acc = {
t: {'acc@1': [], 'acc@5': []}
for t in range(
0,
self.diffusion_model.num_timesteps,
self.diffusion_model.log_every_t,
)
}
def on_validation_start(self):
self.reset_noise_accs()
@ -212,24 +285,35 @@ class NoisyLatentImageClassifier(pl.LightningModule):
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
self.noisy_acc[t]['acc@1'].append(
self.compute_top_k(logits, targets, k=1, reduction='mean')
)
self.noisy_acc[t]['acc@5'].append(
self.compute_top_k(logits, targets, k=5, reduction='mean')
)
return loss
def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
optimizer, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
}]
'frequency': 1,
}
]
return [optimizer], scheduler
return optimizer
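Aside: the scheduler is returned in Lightning's dict form with 'interval': 'step', so the LambdaLR advances once per optimizer step rather than once per epoch. In bare PyTorch the equivalent loop looks roughly like this (toy model and warm-up schedule, purely illustrative):

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 2)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 100))

for step in range(5):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()                                    # 'interval': 'step' -> advance every batch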
@ -243,7 +327,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
y = self.get_conditioning(batch)
if self.label_key == 'class_label':
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
log['labels'] = y
if ismap(y):
@ -256,10 +340,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
log[f'inputs@t{current_time}'] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = F.one_hot(
logits.argmax(dim=1), num_classes=self.num_classes
)
pred = rearrange(pred, 'b h w c -> b c h w')
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
pred
)
for key in log:
log[key] = log[key][:N]

View File

@ -5,12 +5,16 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
@ -23,39 +27,87 @@ class DDIMSampler(object):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
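Aside: `ddim_sigmas_for_original_num_steps` implements the DDIM noise scale sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)); with eta = 0 the sampler becomes deterministic. The relation in isolation, on a toy schedule (names illustrative):

import numpy as np

alphas_cumprod = np.linspace(0.9999, 0.01, 1000)        # toy schedule
a_t, a_prev = alphas_cumprod[500], alphas_cumprod[499]

sigma_t = 1.0 * np.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))   # eta = 1.0
# eta = 0.0 gives sigma_t = 0, i.e. fully deterministic DDIM steps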
@torch.no_grad()
def sample(self,
def sample(
self,
S,
batch_size,
shape,
@ -64,29 +116,33 @@ class DDIMSampler(object):
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
eta=0.0,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
@ -94,11 +150,14 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
@ -112,12 +171,26 @@ class DDIMSampler(object):
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@ -126,17 +199,38 @@ class DDIMSampler(object):
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True)
iterator = tqdm(
time_range,
desc='DDIM Sampler',
total=total_steps,
dynamic_ncols=True,
)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@ -144,18 +238,30 @@ class DDIMSampler(object):
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
@ -164,42 +270,82 @@ class DDIMSampler(object):
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
alphas = (
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = (
sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
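Aside: the tail of `p_sample_ddim` is the DDIM update itself: estimate x_0 from the noise prediction, then recombine it with a direction term and, if eta > 0, fresh noise. The same three lines on toy tensors (illustrative values):

import torch

a_t, a_prev, sigma_t = torch.tensor(0.5), torch.tensor(0.6), torch.tensor(0.1)
x = torch.randn(1, 4, 8, 8)                              # current latent x_t
e_t = torch.randn(1, 4, 8, 8)                            # predicted noise

pred_x0 = (x - (1 - a_t).sqrt() * e_t) / a_t.sqrt()      # current estimate of x_0
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t      # direction pointing to x_t
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * torch.randn_like(x)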
@ -217,26 +363,51 @@ class DDIMSampler(object):
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False):
def decode(
self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
unconditional_conditioning=unconditional_conditioning,
)
return x_dec

File diff suppressed because it is too large.

View File

@ -1,8 +1,9 @@
'''wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers'''
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
import torch.nn as nn
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
@ -15,8 +16,9 @@ class CFGDenoiser(nn.Module):
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
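Aside: `CFGDenoiser.forward` is classifier-free guidance: the model is run on the unconditional and conditional halves of a doubled batch and the two predictions are extrapolated. The combination step on its own (toy tensors, illustrative scale):

import torch

uncond = torch.randn(1, 4, 8, 8)          # prediction without conditioning
cond = torch.randn(1, 4, 8, 8)            # prediction with the text conditioning
cond_scale = 7.5

guided = uncond + (cond - uncond) * cond_scale
# cond_scale = 1.0 reproduces the conditional prediction; larger values push further from uncond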
class KSampler(object):
def __init__(self, model, schedule="lms", device="cuda", **kwargs):
def __init__(self, model, schedule='lms', device='cuda', **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
@ -26,14 +28,16 @@ class KSampler(object):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
# most of these arguments are ignored and are only present for compatibility with
# other samplers
@torch.no_grad()
def sample(self,
def sample(
self,
S,
batch_size,
shape,
@ -42,28 +46,39 @@ class KSampler(object):
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
eta=0.0,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
**kwargs,
):
sigmas = self.model.get_sigmas(S)
if x_T:
x = x_T
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
x = (
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args),
None)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args
),
None,
)

View File

@ -5,11 +5,15 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class PLMSSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
@ -23,41 +27,89 @@ class PLMSSampler(object):
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def sample(self,
def sample(
self,
S,
batch_size,
shape,
@ -66,41 +118,48 @@ class PLMSSampler(object):
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
eta=0.0,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for PLMS sampling is {size}')
# print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
@ -114,12 +173,26 @@ class PLMSSampler(object):
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@ -128,42 +201,81 @@ class PLMSSampler(object):
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running PLMS Sampling with {total_steps} timesteps")
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
# print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True)
iterator = tqdm(
time_range,
desc='PLMS Sampler',
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
@ -172,47 +284,95 @@ class PLMSSampler(object):
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
e_t_uncond, e_t = self.model.apply_model(
x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
alphas = (
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
a_prev = torch.full(
(b, 1, 1, 1), alphas_prev[index], device=device
)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = (
sigma_t
* noise_like(x.shape, device, repeat_noise)
* temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@ -231,7 +391,12 @@ class PLMSSampler(object):
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
e_t_prime = (
55 * e_t
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
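Aside: `p_sample_plms` keeps the last few noise predictions in `old_eps` and blends them with Adams-Bashforth coefficients before taking a DDIM-style step; only the third- and fourth-order branches are visible in this hunk. A sketch of the coefficient logic (the no-history warm-up in the real code uses a separate second-order estimate, simplified here):

import torch

e_t = torch.randn(1, 4, 8, 8)                            # current noise prediction
old_eps = [torch.randn_like(e_t) for _ in range(3)]      # previous predictions, oldest first

if len(old_eps) >= 3:    # 4th order Adams-Bashforth
    e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
elif len(old_eps) == 2:  # 3rd order
    e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) == 1:  # 2nd order
    e_t_prime = (3 * e_t - old_eps[-1]) / 2
else:                    # simplified: fall back to the raw prediction
    e_t_prime = e_t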

View File

@ -13,7 +13,7 @@ def exists(val):
def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()
def default(val, d):
@ -45,19 +45,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
@ -74,7 +73,9 @@ def zero_module(module):
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
@ -82,17 +83,28 @@ class LinearAttention(nn.Module):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
q, k, v = rearrange(
qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3,
)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w,
)
return self.to_out(out)
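Aside: `LinearAttention` never builds the full (h*w) x (h*w) attention matrix: keys are softmaxed over positions, compressed into a small per-head context, and queries read from that context, so the cost grows linearly in the number of positions. The two einsums with explicit toy shapes (illustrative):

import torch

b, heads, d, n = 2, 4, 32, 64 * 64                       # batch, heads, head dim, h*w positions
q = torch.randn(b, heads, d, n)
k = torch.randn(b, heads, d, n).softmax(dim=-1)          # keys normalized over positions
v = torch.randn(b, heads, d, n)

context = torch.einsum('bhdn,bhen->bhde', k, v)          # (b, heads, d, d): positions summed out
out = torch.einsum('bhde,bhdn->bhen', context, q)        # (b, heads, d, n): linear in n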
@ -102,26 +114,18 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
@ -131,12 +135,12 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
@ -146,16 +150,18 @@ class SpatialSelfAttention(nn.Module):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@ -163,8 +169,7 @@ class CrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
@ -175,7 +180,9 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)
)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@ -194,19 +201,37 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
@ -223,29 +248,43 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
)
)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention

File diff suppressed because it is too large.

View File

@ -24,6 +24,7 @@ from ldm.modules.attention import SpatialTransformer
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
@ -42,7 +43,9 @@ class AttentionPool2d(nn.Module):
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
@ -97,37 +100,45 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
"""Learned 2x upsampling without padding"""
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2
)
def forward(self,x):
def forward(self, x):
return self.up(x)
@ -140,7 +151,9 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@ -149,7 +162,12 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
)
else:
assert self.channels == self.out_channels
@ -219,7 +237,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
2 * self.out_channels
if use_scale_shift_norm
else self.out_channels,
),
)
self.out_layers = nn.Sequential(
@ -227,7 +247,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1
)
),
)
@ -238,7 +260,9 @@ class ResBlock(TimestepBlock):
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1
)
def forward(self, x, emb):
"""
@ -251,7 +275,6 @@ class ResBlock(TimestepBlock):
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@ -297,7 +320,7 @@ class AttentionBlock(nn.Module):
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
@ -312,8 +335,10 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
return checkpoint(
self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
@ -340,7 +365,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
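Aside: `count_flops_attn` charges the two attention matmuls. Plugged into concrete numbers, say one sample with 64 channels on a 32x32 feature map (illustrative values):

b, c = 1, 64
num_spatial = 32 * 32                                    # 1024 positions
matmul_ops = 2 * b * num_spatial ** 2 * c
print(matmul_ops)                                        # 134217728 multiply-accumulates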
@ -362,13 +387,15 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1
)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)
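Aside: scaling q and k each by 1 / sqrt(sqrt(ch)) before the matmul is mathematically the same as dividing the logits by sqrt(ch) afterwards, but it keeps intermediates smaller, which is what the f16 comment refers to. A quick equivalence check (illustrative shapes):

import torch

ch = 64
q, k = torch.randn(2, ch, 16), torch.randn(2, ch, 16)
scale = 1 / ch ** 0.25                                    # 1 / sqrt(sqrt(ch))

w1 = torch.einsum('bct,bcs->bts', q * scale, k * scale)   # scale both operands
w2 = torch.einsum('bct,bcs->bts', q, k) / ch ** 0.5       # scale afterwards
print(torch.allclose(w1, w2, atol=1e-5))                  # True up to float error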
@staticmethod
@ -397,12 +424,14 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
a = th.einsum(
'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
)
return a.reshape(bs, -1, length)
@staticmethod
@ -469,11 +498,16 @@ class UNetModel(nn.Module):
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert (
context_dim is not None
), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert (
use_spatial_transformer
), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
@ -481,10 +515,14 @@ class UNetModel(nn.Module):
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
assert (
num_head_channels != -1
), 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
assert (
num_heads != -1
), 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
@ -545,8 +583,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
@ -554,8 +596,14 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -592,8 +640,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
@ -609,8 +661,14 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
ResBlock(
ch,
@ -646,8 +704,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
@ -655,8 +717,14 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
if level and i == num_res_blocks:
@ -673,7 +741,9 @@ class UNetModel(nn.Module):
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -682,13 +752,15 @@ class UNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)
),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
@ -707,7 +779,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@ -718,9 +790,11 @@ class UNetModel(nn.Module):
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
), 'must specify y if and only if the model is class-conditional'
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
@ -768,9 +842,9 @@ class EncoderUNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
pool='adaptive',
*args,
**kwargs
**kwargs,
):
super().__init__()
@ -888,7 +962,7 @@ class EncoderUNetModel(nn.Module):
)
self._feature_size += ch
self.pool = pool
if pool == "adaptive":
if pool == 'adaptive':
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
@ -896,7 +970,7 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
elif pool == 'attention':
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
@ -905,13 +979,13 @@ class EncoderUNetModel(nn.Module):
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == "spatial":
elif pool == 'spatial':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
elif pool == 'spatial_v2':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
@ -919,7 +993,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
raise NotImplementedError(f'Unexpected {pool} pooling')
def convert_to_fp16(self):
"""
@ -942,20 +1016,21 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
else:
h = h.type(x.dtype)
return self.out(h)


@ -18,15 +18,24 @@ from einops import repeat
from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == 'linear':
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64,
)
** 2
)
elif schedule == "cosine":
elif schedule == 'cosine':
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
@ -34,23 +43,41 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
elif schedule == 'sqrt_linear':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == 'sqrt':
betas = (
torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
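For orientation, a minimal standalone sketch of the 'linear' branch above, using the usual default endpoints (assumed here, not read from a config):

import torch

# Interpolate in sqrt-space and square, so betas rise from ~1e-4 to ~2e-2.
betas = (
    torch.linspace(1e-4**0.5, 2e-2**0.5, 1000, dtype=torch.float64) ** 2
)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # consumed by the samplers
print(betas[0].item(), betas[-1].item())             # ~1e-4, ~2e-2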
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
ddim_timesteps = (
(
np.linspace(
0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps
)
)
** 2
).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
@ -60,17 +87,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
def make_ddim_sampling_parameters(
alphacums, ddim_timesteps, eta, verbose=True
):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
alphas_prev = np.asarray(
[alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
)
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
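A toy, self-contained illustration of the sigma formula above (eq. 16 of the DDIM paper); the cumulative alphas are made up for the sketch, and eta=0 would give the deterministic sampler:

import numpy as np

alphacums = np.linspace(0.999, 0.01, 1000)            # assumed schedule
ddim_timesteps = np.arange(0, 1000, 100) + 1          # 10 uniform steps
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray(
    [alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
)
eta = 1.0
sigmas = eta * np.sqrt(
    (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)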
@ -109,7 +146,9 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if False: # disabled checkpointing to allow requires_grad = False for main model
if (
False
): # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
@ -129,7 +168,9 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
@ -160,12 +201,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
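A compact standalone sketch of the sinusoidal branch above, for a single timestep and a small embedding width (values assumed):

import math
import torch

t, dim, max_period = torch.tensor([10.0]), 8, 10000
half = dim // 2
freqs = torch.exp(
    -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
)
args = t[:, None] * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(emb.shape)   # torch.Size([1, 8])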
@ -215,6 +260,7 @@ class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
@ -225,7 +271,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs):
@ -245,15 +291,16 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config
)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
@ -262,6 +309,8 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()


@ -30,33 +30,45 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1,2,3]):
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
logtwopi
+ self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
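Sanity-check sketch for the closed-form KL above: against a standard normal it reduces to 0.5 * sum(mu^2 + var - 1 - logvar) per element, which vanishes when the posterior is itself N(0, I) (shapes below are assumed):

import torch

mean, logvar = torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 8, 8)
var = logvar.exp()
kl = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])
print(kl)   # tensor([0.])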
@ -74,7 +86,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().


@ -10,24 +10,30 @@ class LitEma(nn.Module):
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
else torch.tensor(-1,dtype=torch.int))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
#remove as '.'-character is not allowed in buffers
s_name = name.replace('.','')
self.m_name2s_name.update({name:s_name})
self.register_buffer(s_name,p.clone().detach().data)
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self,model):
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
decay = min(
self.decay, (1 + self.num_updates) / (10 + self.num_updates)
)
one_minus_decay = 1.0 - decay
@ -38,8 +44,12 @@ class LitEma(nn.Module):
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key]
)
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
@ -48,7 +58,9 @@ class LitEma(nn.Module):
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data
)
else:
assert not key in self.m_name2s_name
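A scalar sketch of the EMA rule used above (not part of the commit): with the num_updates warm-up, the effective decay starts near zero and grows toward the configured value, so early updates track the raw weights closely.

decay_cfg, shadow, param = 0.9999, 0.0, 1.0
for num_updates in range(1, 6):
    decay = min(decay_cfg, (1 + num_updates) / (10 + num_updates))
    shadow = shadow - (1.0 - decay) * (shadow - param)   # same update as above
    print(num_updates, round(decay, 3), round(shadow, 3))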


@ -8,18 +8,29 @@ from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer
from functools import partial
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
DEFAULT_PLACEHOLDER_TOKEN = ['*']
PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
batch_encoding = tokenizer(
string,
truncation=True,
max_length=77,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids']
assert (
torch.count_nonzero(tokens - 49407) == 2
), f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
@ -28,6 +39,7 @@ def get_bert_token_for_string(tokenizer, string):
return token
def get_embedding_for_clip_token(embedder, token):
return embedder(token.unsqueeze(0))[0, 0]
@ -41,7 +53,7 @@ class EmbeddingManager(nn.Module):
per_image_tokens=False,
num_vectors_per_token=1,
progressive_words=False,
**kwargs
**kwargs,
):
super().__init__()
@ -49,21 +61,32 @@ class EmbeddingManager(nn.Module):
self.string_to_param_dict = nn.ParameterDict()
self.initial_embeddings = nn.ParameterDict() # These should not be optimized
self.initial_embeddings = (
nn.ParameterDict()
) # These should not be optimized
self.progressive_words = progressive_words
self.progressive_counter = 0
self.max_vectors_per_token = num_vectors_per_token
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
if hasattr(
embedder, 'tokenizer'
): # using Stable Diffusion's CLIP encoder
self.is_clip = True
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
get_token_for_string = partial(
get_clip_token_for_string, embedder.tokenizer
)
get_embedding_for_tkn = partial(
get_embedding_for_clip_token,
embedder.transformer.text_model.embeddings,
)
token_dim = 1280
else: # using LDM's BERT encoder
self.is_clip = False
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
get_token_for_string = partial(
get_bert_token_for_string, embedder.tknz_fn
)
get_embedding_for_tkn = embedder.transformer.token_emb
token_dim = 1280
@ -78,12 +101,31 @@ class EmbeddingManager(nn.Module):
init_word_token = get_token_for_string(initializer_words[idx])
with torch.no_grad():
init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())
init_word_embedding = get_embedding_for_tkn(
init_word_token.cpu()
)
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
token_params = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=True,
)
self.initial_embeddings[
placeholder_string
] = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=False,
)
else:
token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
token_params = torch.nn.Parameter(
torch.rand(
size=(num_vectors_per_token, token_dim),
requires_grad=True,
)
)
self.string_to_token_dict[placeholder_string] = token
self.string_to_param_dict[placeholder_string] = token_params
@ -95,36 +137,69 @@ class EmbeddingManager(nn.Module):
):
b, n, device = *tokenized_text.shape, tokenized_text.device
for placeholder_string, placeholder_token in self.string_to_token_dict.items():
for (
placeholder_string,
placeholder_token,
) in self.string_to_token_dict.items():
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
placeholder_embedding = self.string_to_param_dict[
placeholder_string
].to(device)
if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
if (
self.max_vectors_per_token == 1
): # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(
tokenized_text == placeholder_token.to(device)
)
embedded_text[placeholder_idx] = placeholder_embedding
else: # otherwise, need to insert and keep track of changing indices
if self.progressive_words:
self.progressive_counter += 1
max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
max_step_tokens = (
1 + self.progressive_counter // PROGRESSIVE_SCALE
)
else:
max_step_tokens = self.max_vectors_per_token
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
num_vectors_for_token = min(
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if placeholder_rows.nelement() == 0:
continue
sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
sorted_cols, sort_idx = torch.sort(
placeholder_cols, descending=True
)
sorted_rows = placeholder_rows[sort_idx]
for idx in range(len(sorted_rows)):
row = sorted_rows[idx]
col = sorted_cols[idx]
new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
new_token_row = torch.cat(
[
tokenized_text[row][:col],
placeholder_token.repeat(num_vectors_for_token).to(
device
),
tokenized_text[row][col + 1 :],
],
axis=0,
)[:n]
new_embed_row = torch.cat(
[
embedded_text[row][:col],
placeholder_embedding[:num_vectors_for_token],
embedded_text[row][col + 1 :],
],
axis=0,
)[:n]
embedded_text[row] = new_embed_row
tokenized_text[row] = new_token_row
@ -132,18 +207,27 @@ class EmbeddingManager(nn.Module):
return embedded_text
def save(self, ckpt_path):
torch.save({"string_to_token": self.string_to_token_dict,
"string_to_param": self.string_to_param_dict}, ckpt_path)
torch.save(
{
'string_to_token': self.string_to_token_dict,
'string_to_param': self.string_to_param_dict,
},
ckpt_path,
)
def load(self, ckpt_path):
ckpt = torch.load(ckpt_path, map_location='cpu')
self.string_to_token_dict = ckpt["string_to_token"]
self.string_to_param_dict = ckpt["string_to_param"]
self.string_to_token_dict = ckpt['string_to_token']
self.string_to_param_dict = ckpt['string_to_param']
def get_embedding_norms_squared(self):
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
all_params = torch.cat(
list(self.string_to_param_dict.values()), axis=0
) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(
axis=-1
) # num_placeholders
return param_norm_squared
@ -152,13 +236,18 @@ class EmbeddingManager(nn.Module):
def embedding_to_coarse_loss(self):
loss = 0.
loss = 0.0
num_embeddings = len(self.initial_embeddings)
for key in self.initial_embeddings:
optimized = self.string_to_param_dict[key]
coarse = self.initial_embeddings[key].clone().to(optimized.device)
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
loss = (
loss
+ (optimized - coarse)
@ (optimized - coarse).T
/ num_embeddings
)
return loss
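For readers new to this file, a schematic of the single-vector replacement path above (token id, shapes, and values are invented for the sketch): wherever the placeholder id occurs, its row of embedded_text is overwritten with the learned parameter.

import torch

placeholder_token = torch.tensor(265)                # assumed id for '*'
tokenized_text = torch.tensor([[11, 265, 42, 7]])    # (b=1, n=4)
embedded_text = torch.zeros(1, 4, 6)                 # (b, n, token_dim)
placeholder_embedding = torch.ones(1, 6)             # optimized vector

idx = torch.where(tokenized_text == placeholder_token)
embedded_text[idx] = placeholder_embedding
print(embedded_text[0, 1])   # tensor([1., 1., 1., 1., 1., 1.])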


@ -6,20 +6,29 @@ from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,
) # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
def _expand_mask(mask, dtype, tgt_len = None):
def _expand_mask(mask, dtype, tgt_len=None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
@ -30,6 +39,7 @@ def _build_causal_attention_mask(bsz, seq_len, dtype):
mask = mask.unsqueeze(1) # expand mask
return mask
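In the same spirit as the helper above (its full body is not shown in this hunk), a minimal causal mask can be built by filling the strictly upper triangle with a very large negative number so softmax ignores future positions:

import torch

bsz, seq_len, dtype = 1, 4, torch.float32
mask = torch.full((bsz, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
mask = torch.triu(mask, diagonal=1).unsqueeze(1)   # (bsz, 1, seq_len, seq_len)
print(mask[0, 0])   # zeros on/below the diagonal, ~-3.4e38 above it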
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
@ -38,7 +48,6 @@ class AbstractEncoder(nn.Module):
raise NotImplementedError
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
@ -56,11 +65,17 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(
self, n_embed, n_layer, vocab_size, max_seq_len=77, device='cuda'
):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
)
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
@ -72,27 +87,42 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device='cuda', vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to requirements
from transformers import (
BertTokenizerFast,
) # TODO: add to requirements
# Modified to allow to run on non-internet connected compute nodes.
# Model needs to be loaded into cache from an internet-connected machine
# by running:
# from transformers import BertTokenizerFast
# BertTokenizerFast.from_pretrained("bert-base-uncased")
try:
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
self.tokenizer = BertTokenizerFast.from_pretrained(
'bert-base-uncased', local_files_only=True
)
except OSError:
raise SystemExit("* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-connected machine.")
raise SystemExit(
"* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine."
)
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
return tokens
@torch.no_grad()
@ -108,53 +138,84 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device='cuda',
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len
)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
emb_dropout=embedding_dropout,
)
def forward(self, text, embedding_manager=None):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)
z = self.transformer(
tokens, return_embeddings=True, embedding_manager=embedding_manager
)
return z
def encode(self, text, **kwargs):
# output of length 77
return self(text, **kwargs)
class SpatialRescaler(nn.Module):
def __init__(self,
def __init__(
self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
bias=False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
assert method in [
'nearest',
'linear',
'bilinear',
'trilinear',
'bicubic',
'area',
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.interpolator = partial(
torch.nn.functional.interpolate, mode=method
)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
print(
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias
)
def forward(self,x):
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
@ -162,25 +223,40 @@ class SpatialRescaler(nn.Module):
def encode(self, x):
return self(x)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
def __init__(
self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version,local_files_only=True)
self.transformer = CLIPTextModel.from_pretrained(version,local_files_only=True)
self.tokenizer = CLIPTokenizer.from_pretrained(
version, local_files_only=True
)
self.transformer = CLIPTextModel.from_pretrained(
version, local_files_only=True
)
self.device = device
self.max_length = max_length
self.freeze()
def embedding_forward(
self,
input_ids = None,
position_ids = None,
inputs_embeds = None,
embedding_manager = None,
input_ids=None,
position_ids=None,
inputs_embeds=None,
embedding_manager=None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
seq_length = (
input_ids.shape[-1]
if input_ids is not None
else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
@ -191,28 +267,39 @@ class FrozenCLIPEmbedder(AbstractEncoder):
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
self.transformer.text_model.embeddings.forward = (
embedding_forward.__get__(self.transformer.text_model.embeddings)
)
def encoder_forward(
self,
inputs_embeds,
attention_mask = None,
causal_attention_mask = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
attention_mask=None,
causal_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@ -239,44 +326,61 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
self.transformer.text_model.encoder
)
def text_encoder_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
raise ValueError('You have to specify either input_ids')
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding_manager=embedding_manager,
)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
causal_attention_mask = _build_causal_attention_mask(
bsz, seq_len, hidden_states.dtype
).to(hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
attention_mask = _expand_mask(
attention_mask, hidden_states.dtype
)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
@ -291,17 +395,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
self.transformer.text_model.forward = text_encoder_forward.__get__(
self.transformer.text_model
)
def transformer_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
return self.text_model(
input_ids=input_ids,
@ -310,11 +416,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
embedding_manager = embedding_manager
embedding_manager=embedding_manager,
)
self.transformer.forward = transformer_forward.__get__(self.transformer)
self.transformer.forward = transformer_forward.__get__(
self.transformer
)
def freeze(self):
self.transformer = self.transformer.eval()
@ -322,9 +429,16 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False
def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs)
return z
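The forward overrides above rely on binding plain functions to existing instances via __get__; a tiny self-contained illustration of that trick (toy class, not from this repo):

class Greeter:
    def forward(self):
        return 'original'


def patched_forward(self):
    return 'patched'


g = Greeter()
g.forward = patched_forward.__get__(g)   # bind to this one instance
print(g.forward())                       # 'patched'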
@ -337,9 +451,17 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
def __init__(
self,
version='ViT-L/14',
device='cuda',
max_length=77,
n_repeat=1,
normalize=True,
):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.model, _ = clip.load(version, jit=False, device='cpu')
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
@ -359,7 +481,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
def encode(self, text):
z = self(text)
if z.ndim==2:
if z.ndim == 2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z
@ -369,6 +491,7 @@ class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model,
@ -381,15 +504,27 @@ class FrozenClipImageEmbedder(nn.Module):
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False,
)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False,
)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
x = kornia.geometry.resize(
x,
(224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
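An equivalent-in-spirit preprocessing sketch in plain torch (the module above uses kornia for the resize, so interpolation details will differ slightly):

import torch
import torch.nn.functional as F

mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

x = torch.rand(1, 3, 512, 512) * 2 - 1   # model-space input in [-1, 1]
x = F.interpolate(x, (224, 224), mode='bicubic', align_corners=True)
x = (x + 1.0) / 2.0                      # back to [0, 1]
x = (x - mean) / std                     # CLIP normalization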
@ -399,7 +534,8 @@ class FrozenClipImageEmbedder(nn.Module):
return self.model.encode_image(self.preprocess(x))
if __name__ == "__main__":
if __name__ == '__main__':
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)


@ -1,2 +1,6 @@
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
from ldm.modules.image_degradation.bsrgan import (
degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
degradation_bsrgan_variant as degradation_fn_bsr_light,
)


@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
'''
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
return im[: w - w % sf, : h - h % sf, ...]
"""
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
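Numerically, the covariance construction above rotates a diagonal covariance diag(l1, l2) by theta; a small standalone check (values assumed):

import numpy as np

theta, l1, l2 = np.pi / 4, 6.0, 1.0
v = np.dot(
    np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
    np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = V @ D @ np.linalg.inv(V)
print(Sigma)   # symmetric 2x2 covariance of the anisotropic Gaussian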
@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k):
'''
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
Q = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
[x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs):
'''
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
"""
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3):
'''
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681},
year={2019}
}
'''
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
def classical_degradation(x, k, sf=3):
''' blur + downsampling
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
@ -328,10 +350,19 @@ def add_blur(img, sf=4):
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
k = anisotropic_Gaussian(
ksize=2 * random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
k = fspecial(
'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img
@ -344,7 +375,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
@ -366,19 +401,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -388,28 +430,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@ -418,7 +469,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
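A self-contained round trip of the JPEG degradation above on a random image, with the quality fixed at 40 instead of the random 30-95 draw:

import cv2
import numpy as np

img = np.random.rand(64, 64, 3).astype(np.float32)   # RGB in [0, 1]
bgr = cv2.cvtColor((img * 255).round().astype(np.uint8), cv2.COLOR_RGB2BGR)
ok, enc = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 40])
dec = cv2.cvtColor(cv2.imdecode(enc, 1), cv2.COLOR_BGR2RGB)
print(ok, float(np.abs(dec.astype(np.float32) / 255.0 - img).mean()))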
@ -428,10 +481,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
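Usage sketch (assuming random_crop from this module is in scope): the HQ patch is sf times larger and starts at sf times the LQ offset, so the pair stays pixel-aligned.

import numpy as np

sf, lq_patchsize = 4, 64
lq = np.zeros((128, 128, 3), dtype=np.float32)
hq = np.zeros((512, 512, 3), dtype=np.float32)
lq_crop, hq_crop = random_crop(lq, hq, sf=sf, lq_patchsize=lq_patchsize)
print(lq_crop.shape, hq_crop.shape)   # (64, 64, 3) (256, 256, 3)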
@ -452,7 +509,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -462,8 +519,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
@ -472,7 +532,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -487,19 +550,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
@ -544,15 +618,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
@ -561,7 +638,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -576,19 +656,33 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
@ -609,12 +703,19 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image":image}
example = {'image': image}
return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
def degradation_bsrgan_plus(
img,
sf=4,
shuffle_prob=0.5,
use_sharp=True,
lq_patchsize=64,
isp_model=None,
):
"""
This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN
@ -630,7 +731,7 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
"""
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -645,8 +746,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
else:
shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
shuffle_order[2:6] = random.sample(
shuffle_order[2:6], len(range(2, 6))
)
shuffle_order[9:13] = random.sample(
shuffle_order[9:13], len(range(9, 13))
)
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
@ -689,8 +794,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
print('check the shuffle!')
# resize to desired size
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
# add final JPEG compression noise
img = add_JPEG_noise(img)
@ -702,29 +810,37 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
if __name__ == '__main__':
print("hey")
print('hey')
img = util.imread_uint('utils/test.png', 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
print('resizing to', h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img)['image']
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png')
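# --- Editor's usage sketch (not part of this commit) ---------------------------------
# degradation_bsrgan_variant() above returns a dict whose single 'image' key holds the
# degraded low-quality image as uint8. A minimal call, assuming this file is importable
# as ldm.modules.image_degradation.bsrgan (the exact module path is an assumption):
import numpy as np
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant

hq = np.random.rand(256, 256, 3).astype(np.float32)   # stand-in HQ image in [0, 1]
lq = degradation_bsrgan_variant(hq, sf=4)['image']    # uint8 LQ image, roughly sf times smaller per side
print(hq.shape, lq.shape, lq.dtype)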
View File
@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
'''
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
return im[: w - w % sf, : h - h % sf, ...]
"""
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k):
'''
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
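# Editor's sketch (not part of this commit): blur() above expects batched torch tensors,
# x of shape (N, C, H, W) and one single-channel kernel per batch element, k of shape
# (N, 1, h, w); replicate padding keeps the spatial size for odd kernel sizes. Assumes
# this snippet runs in the same module as blur().
import torch

x = torch.rand(2, 3, 64, 64)        # two RGB images
k = torch.ones(2, 1, 5, 5) / 25.0   # a 5x5 box kernel for each image
y = blur(x, k)
print(y.shape)                      # torch.Size([2, 3, 64, 64])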
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
Q = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
[x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs):
'''
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
"""
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3):
'''
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681},
year={2019}
}
'''
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
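# Editor's sketch (not part of this commit): srmd_degradation() blurs and then bicubically
# downsamples, while dpsr_degradation() applies the same two steps in the opposite order.
# Assuming this runs in the same module, with a Gaussian kernel from fspecial():
import numpy as np

x = np.random.rand(128, 128, 3).astype(np.float32)
k = fspecial('gaussian', 15, 1.5)           # 15x15 Gaussian blur kernel
lr_srmd = srmd_degradation(x, k, sf=4)      # blur, then bicubic downsample
lr_dpsr = dpsr_degradation(x, k, sf=4)      # bicubic downsample, then blur
print(lr_srmd.shape, lr_dpsr.shape)         # both (32, 32, 3)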
def classical_degradation(x, k, sf=3):
''' blur + downsampling
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
@ -326,16 +348,25 @@ def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
wd2 = wd2/4
wd = wd/4
wd2 = wd2 / 4
wd = wd / 4
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
k = anisotropic_Gaussian(
ksize=random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
k = fspecial(
'gaussian', random.randint(2, 4) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img
@ -348,7 +379,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
@ -370,19 +405,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -392,28 +434,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@ -422,7 +473,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
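# Editor's sketch (not part of this commit): add_JPEG_noise() round-trips the image through
# an in-memory JPEG at a random quality in [80, 95]; the underlying cv2 calls, shown here
# for a fixed quality of 90:
import cv2
import numpy as np

img_uint8 = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
ok, enc = cv2.imencode('.jpg', img_uint8, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
dec = cv2.imdecode(enc, 1)          # 1 == cv2.IMREAD_COLOR
print(ok, dec.shape)                # True (64, 64, 3)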
@ -432,10 +485,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
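# Editor's sketch (not part of this commit): random_crop() cuts an aligned pair of patches,
# lq_patchsize square from the LQ image and (lq_patchsize * sf) square from the HQ image.
# Assumes this runs in the same module.
import numpy as np

lq_full = np.random.rand(100, 100, 3)
hq_full = np.random.rand(400, 400, 3)     # sf = 4 times larger than lq_full
lq, hq = random_crop(lq_full, hq_full, sf=4, lq_patchsize=64)
print(lq.shape, hq.shape)                 # (64, 64, 3) (256, 256, 3)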
@ -456,7 +513,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -466,8 +523,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
@ -476,7 +536,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -491,19 +554,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
@ -548,15 +622,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
@ -565,7 +642,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -583,20 +663,34 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2
if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
@ -617,34 +711,41 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
example = {'image': image}
return example
if __name__ == '__main__':
print("hey")
print('hey')
img = util.imread_uint('utils/test.png', 3)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
print('resizing to', h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_hq = img
img_lq = deg_fn(img)["image"]
img_lq = deg_fn(img)['image']
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img_hq)['image']
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png')
View File
@ -6,13 +6,14 @@ import torch
import cv2
from torchvision.utils import make_grid
from datetime import datetime
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
'''
"""
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
@ -20,10 +21,22 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
"""
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
IMG_EXTENSIONS = [
'.jpg',
'.JPG',
'.jpeg',
'.JPEG',
'.png',
'.PNG',
'.ppm',
'.PPM',
'.bmp',
'.BMP',
'.tif',
]
def is_image_file(filename):
@ -49,19 +62,19 @@ def surf(Z, cmap='rainbow', figsize=None):
ax3 = plt.axes(projection='3d')
w, h = Z.shape[:2]
xx = np.arange(0,w,1)
yy = np.arange(0,h,1)
xx = np.arange(0, w, 1)
yy = np.arange(0, h, 1)
X, Y = np.meshgrid(xx, yy)
ax3.plot_surface(X,Y,Z,cmap=cmap)
#ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
ax3.plot_surface(X, Y, Z, cmap=cmap)
# ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
plt.show()
'''
"""
# --------------------------------------------
# get image pathes
# --------------------------------------------
'''
"""
def get_image_paths(dataroot):
@ -83,26 +96,26 @@ def _get_paths_from_images(path):
return images
'''
"""
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
"""
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2]
patches = []
if w > p_max and h > p_max:
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
w1.append(w-p_size)
h1.append(h-p_size)
# print(w1)
# print(h1)
w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))
h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))
w1.append(w - p_size)
h1.append(h - p_size)
# print(w1)
# print(h1)
for i in w1:
for j in h1:
patches.append(img[i:i+p_size, j:j+p_size,:])
patches.append(img[i : i + p_size, j : j + p_size, :])
else:
patches.append(img)
@ -118,11 +131,21 @@ def imssave(imgs, img_path):
for i, img in enumerate(imgs):
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
new_path = os.path.join(
os.path.dirname(img_path),
img_name + str('_s{:04d}'.format(i)) + '.png',
)
cv2.imwrite(new_path, img)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
def split_imageset(
original_dataroot,
taget_dataroot,
n_channels=3,
p_size=800,
p_overlap=96,
p_max=1000,
):
"""
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
@ -139,15 +162,18 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800,
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
#if original_dataroot == taget_dataroot:
#del img_path
imssave(
patches, os.path.join(taget_dataroot, os.path.basename(img_path))
)
# if original_dataroot == taget_dataroot:
# del img_path
'''
"""
# --------------------------------------------
# makedir
# --------------------------------------------
'''
"""
def mkdir(path):
@ -171,12 +197,12 @@ def mkdir_and_rename(path):
os.makedirs(path)
'''
"""
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
"""
# --------------------------------------------
@ -206,6 +232,7 @@ def imsave(img, img_path):
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
def imwrite(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
@ -213,7 +240,6 @@ def imwrite(img, img_path):
cv2.imwrite(img_path, img)
# --------------------------------------------
# get single image of size HxWxn_channles (BGR)
# --------------------------------------------
@ -221,7 +247,7 @@ def read_img(path):
# read image by cv2
# return: Numpy float32, HWC, BGR, [0,1]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
img = img.astype(np.float32) / 255.
img = img.astype(np.float32) / 255.0
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
@ -230,7 +256,7 @@ def read_img(path):
return img
'''
"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
@ -238,7 +264,7 @@ def read_img(path):
# numpy(single) <---> tensor
# numpy(unit) <---> tensor
# --------------------------------------------
'''
"""
# --------------------------------------------
@ -248,22 +274,22 @@ def read_img(path):
def uint2single(img):
return np.float32(img/255.)
return np.float32(img / 255.0)
def single2uint(img):
return np.uint8((img.clip(0, 1)*255.).round())
return np.uint8((img.clip(0, 1) * 255.0).round())
def uint162single(img):
return np.float32(img/65535.)
return np.float32(img / 65535.0)
def single2uint16(img):
return np.uint16((img.clip(0, 1)*65535.).round())
return np.uint16((img.clip(0, 1) * 65535.0).round())
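# Editor's sketch (not part of this commit): the helpers above convert between uint8/uint16
# images and float32 images in [0, 1]; a round trip only loses sub-rounding precision.
# Assumes this runs in the same module.
import numpy as np

u8 = np.array([[0, 128, 255]], dtype=np.uint8)
f = uint2single(u8)          # float32: 0.0, ~0.502, 1.0
back = single2uint(f)        # uint8 again: 0, 128, 255
print(f.dtype, back.dtype)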
# --------------------------------------------
@ -275,14 +301,25 @@ def single2uint16(img):
def uint2tensor4(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
.unsqueeze(0)
)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
)
# convert 2/3/4-dimensional torch tensor to uint
@ -290,7 +327,7 @@ def tensor2uint(img):
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
return np.uint8((img*255.0).round())
return np.uint8((img * 255.0).round())
# --------------------------------------------
@ -305,7 +342,12 @@ def single2tensor3(img):
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.unsqueeze(0)
)
# convert torch tensor to single
@ -316,6 +358,7 @@ def tensor2single(img):
return img
# convert torch tensor to single
def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy()
@ -327,30 +370,48 @@ def tensor2single3(img):
def single2tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1, 3)
.float()
.unsqueeze(0)
)
def single32tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.float()
.unsqueeze(0)
.unsqueeze(0)
)
def single42tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
return (
torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
)
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
'''
"""
Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
'''
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
"""
tensor = (
tensor.squeeze().float().cpu().clamp_(*min_max)
) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (
min_max[1] - min_max[0]
) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
img_np = make_grid(
tensor, nrow=int(math.sqrt(n_img)), normalize=False
).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3:
img_np = tensor.numpy()
@ -359,14 +420,17 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
img_np = tensor.numpy()
else:
raise TypeError(
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(
n_dim
)
)
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
return img_np.astype(out_type)
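# Editor's sketch (not part of this commit): tensor2img() accepts a 2D/3D/4D tensor in an
# arbitrary range and returns a uint8 numpy image; a CHW input in RGB order comes back as
# HWC in BGR order. Assumes this runs in the same module.
import torch

t = torch.rand(3, 32, 32)      # CHW, values in [0, 1]
img = tensor2img(t)
print(img.shape, img.dtype)    # (32, 32, 3) uint8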
'''
"""
# --------------------------------------------
# Augmentation, flipe and/or rotate
# --------------------------------------------
@ -374,12 +438,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
# (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
"""
def augment_img(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
@ -399,8 +462,7 @@ def augment_img(img, mode=0):
def augment_img_tensor4(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
@ -420,8 +482,7 @@ def augment_img_tensor4(img, mode=0):
def augment_img_tensor(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
img_size = img.size()
img_np = img.data.cpu().numpy()
if len(img_size) == 3:
@ -484,11 +545,11 @@ def augment_imgs(img_list, hflip=True, rot=True):
return [_augment(img) for img in img_list]
'''
"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
"""
def modcrop(img_in, scale):
@ -497,11 +558,11 @@ def modcrop(img_in, scale):
if img.ndim == 2:
H, W = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r]
img = img[: H - H_r, : W - W_r]
elif img.ndim == 3:
H, W, C = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r, :]
img = img[: H - H_r, : W - W_r, :]
else:
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
return img
@ -511,11 +572,11 @@ def shave(img_in, border=0):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
h, w = img.shape[:2]
img = img[border:h-border, border:w-border]
img = img[border : h - border, border : w - border]
return img
'''
"""
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
@ -523,74 +584,92 @@ def shave(img_in, border=0):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
"""
def rgb2ycbcr(img, only_y=True):
'''same as matlab rgb2ycbcr
"""same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
rlt = np.matmul(
img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
def ycbcr2rgb(img):
'''same as matlab ycbcr2rgb
"""same as matlab ycbcr2rgb
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
rlt = np.matmul(
img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
) * 255.0 + [-222.921, 135.576, -276.836]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
'''bgr version of rgb2ycbcr
"""bgr version of rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
rlt = np.matmul(
img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
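# Editor's sketch (not part of this commit): rgb2ycbcr()/bgr2ycbcr() follow MATLAB's
# studio-range convention, so uint8 white maps to Y = 235 rather than 255. Assumes this
# runs in the same module.
import numpy as np

white = np.array([[[255, 255, 255]]], dtype=np.uint8)
print(rgb2ycbcr(white, only_y=True))   # [[235]]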
@ -608,11 +687,11 @@ def channel_convert(in_c, tar_type, img_list):
return img_list
'''
"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
"""
# --------------------------------------------
@ -620,17 +699,17 @@ def channel_convert(in_c, tar_type, img_list):
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
# img1 and img2 have range [0, 255]
#img1 = img1.squeeze()
#img2 = img2.squeeze()
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border]
img2 = img2[border:h-border, border:w-border]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
mse = np.mean((img1 - img2) ** 2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
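# Editor's sketch (not part of this commit): calculate_psnr() expects [0, 255] images;
# identical inputs give inf, and a constant error of one gray level gives
# 20 * log10(255) ~= 48.13 dB. Assumes this runs in the same module.
import numpy as np

a = np.zeros((8, 8), dtype=np.float64)
b = np.ones((8, 8), dtype=np.float64)
print(calculate_psnr(a, a))   # inf
print(calculate_psnr(a, b))   # ~48.13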
@ -640,17 +719,17 @@ def calculate_psnr(img1, img2, border=0):
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
'''calculate SSIM
"""calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
#img1 = img1.squeeze()
#img2 = img2.squeeze()
"""
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border]
img2 = img2[border:h-border, border:w-border]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
if img1.ndim == 2:
return ssim(img1, img2)
@ -658,7 +737,7 @@ def calculate_ssim(img1, img2, border=0):
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
@ -667,8 +746,8 @@ def calculate_ssim(img1, img2, border=0):
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
@ -684,16 +763,17 @@ def ssim(img1, img2):
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
return ssim_map.mean()
'''
"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
"""
# matlab 'imresize' function, now only support 'bicubic'
@ -701,11 +781,14 @@ def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
def calculate_weights_indices(
in_length, out_length, scale, kernel, kernel_width, antialiasing
):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale
@ -729,8 +812,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
1, P).expand(out_length, P)
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
0, P - 1, P
).view(1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
@ -771,7 +855,11 @@ def imresize(img, scale, antialiasing=True):
if need_squeeze:
img.unsqueeze_(0)
in_C, in_H, in_W = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = 'cubic'
@ -782,9 +870,11 @@ def imresize(img, scale, antialiasing=True):
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
@ -805,7 +895,11 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[j, i, :] = (
img_aug[j, idx : idx + kernel_width, :]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension
# symmetric copying
@ -827,7 +921,9 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(
weights_W[i]
)
if need_squeeze:
out_2.squeeze_()
return out_2
@ -846,7 +942,11 @@ def imresize_np(img, scale, antialiasing=True):
img.unsqueeze_(2)
in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = 'cubic'
@ -857,9 +957,11 @@ def imresize_np(img, scale, antialiasing=True):
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
@ -880,7 +982,11 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
out_1[i, :, j] = (
img_aug[idx : idx + kernel_width, :, j]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension
# symmetric copying
@ -902,7 +1008,9 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(
weights_W[i]
)
if need_squeeze:
out_2.squeeze_()
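# Editor's sketch (not part of this commit): imresize_np() mirrors MATLAB's bicubic imresize
# for HxWxC numpy images in [0, 1]; scale < 1 downsamples with antialiasing, scale > 1
# upsamples. Assumes this runs in the same module.
import numpy as np

x = np.random.rand(64, 64, 3).astype(np.float32)
small = imresize_np(x, 1 / 4)   # spatially 4x smaller
big = imresize_np(x, 2)         # spatially 2x larger
print(small.shape, big.shape)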
View File
@ -5,13 +5,24 @@ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/
class LPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_loss="hinge"):
def __init__(
self,
disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss='hinge',
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert disc_loss in ['hinge', 'vanilla']
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
@ -19,42 +30,68 @@ class LPIPSWithDiscriminator(nn.Module):
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm
use_actnorm=use_actnorm,
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_loss = (
hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss
)
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
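# Editor's sketch (not part of this commit): the adaptive weight above rescales the generator
# loss so that its gradient norm at the decoder's last layer matches that of the NLL term
# (the module then multiplies by discriminator_weight). A self-contained toy version of the
# same computation:
import torch

last_layer = torch.nn.Parameter(torch.randn(4))
nll_loss = (last_layer ** 2).sum()      # stand-in reconstruction term
g_loss = (3.0 * last_layer).sum()       # stand-in generator term
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.clamp(torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4), 0.0, 1e4).detach()
print(d_weight)   # ratio of the two gradient norms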
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
global_step, last_layer=None, cond=None, split="train",
weights=None):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
def forward(
self,
inputs,
reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
weights=None,
):
rec_loss = torch.abs(
inputs.contiguous() - reconstructions.contiguous()
)
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights*nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = (
torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
)
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
@ -67,27 +104,42 @@ class LPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
log = {
'{}/total_loss'.format(split): loss.clone().detach().mean(),
'{}/logvar'.format(split): self.logvar.detach(),
'{}/kl_loss'.format(split): kl_loss.detach().mean(),
'{}/nll_loss'.format(split): nll_loss.detach().mean(),
'{}/rec_loss'.format(split): rec_loss.detach().mean(),
'{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
}
return loss, log
@ -95,17 +147,29 @@ class LPIPSWithDiscriminator(nn.Module):
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
log = {
'{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
'{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log
View File
@ -3,21 +3,25 @@ from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.discriminator.model import (
NLayerDiscriminator,
weights_init,
)
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
@ -26,57 +30,76 @@ def adopt_weight(weight, global_step, threshold=0, value=0.):
def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
encodings = (
F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
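# Editor's sketch (not part of this commit): measure_perplexity() approaches n_embed when all
# codebook entries are used uniformly and collapses toward 1 when a single code dominates.
# Assumes this runs in the same module.
import torch

uniform = torch.arange(8)                       # each of 8 codes predicted once
collapsed = torch.zeros(8, dtype=torch.long)    # only code 0 predicted
print(measure_perplexity(uniform, 8))           # (~8.0, 8)
print(measure_perplexity(collapsed, 8))         # (~1.0, 1)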
def l1(x, y):
return torch.abs(x-y)
return torch.abs(x - y)
def l2(x, y):
return torch.pow((x-y), 2)
return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
pixel_loss="l1"):
def __init__(
self,
disc_start,
codebook_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss='hinge',
n_classes=None,
perceptual_loss='lpips',
pixel_loss='l1',
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
assert disc_loss in ['hinge', 'vanilla']
assert perceptual_loss in ['lpips', 'clips', 'dists']
assert pixel_loss in ['l1', 'l2']
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
if perceptual_loss == 'lpips':
print(f'{self.__class__.__name__}: Running with LPIPS.')
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
raise ValueError(
f'Unknown perceptual loss: >> {perceptual_loss} <<'
)
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
if pixel_loss == 'l1':
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
ndf=disc_ndf,
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
if disc_loss == 'hinge':
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
elif disc_loss == 'vanilla':
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.')
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
@ -84,31 +107,53 @@ class VQLPIPSWithDiscriminator(nn.Module):
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
def forward(
self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
predicted_indices=None,
):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device)
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
codebook_loss = torch.tensor([0.0]).to(inputs.device)
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(
inputs.contiguous(), reconstructions.contiguous()
)
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
@ -119,49 +164,77 @@ class VQLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
log = {
'{}/total_loss'.format(split): loss.clone().detach().mean(),
'{}/quant_loss'.format(split): codebook_loss.detach().mean(),
'{}/nll_loss'.format(split): nll_loss.detach().mean(),
'{}/rec_loss'.format(split): rec_loss.detach().mean(),
'{}/p_loss'.format(split): p_loss.detach().mean(),
'{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
perplexity, cluster_usage = measure_perplexity(
predicted_indices, self.n_classes
)
log[f'{split}/perplexity'] = perplexity
log[f'{split}/cluster_usage'] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
log = {
'{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
'{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log
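The adversarial terms above are gated by adopt_weight, which keeps the GAN loss at zero until the discriminator warm-up step is reached. A minimal sketch of that helper, assuming the usual taming-transformers convention (the helper itself is not part of this diff):

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # return `value` (zero by default) until global_step reaches the
    # discriminator_iter_start threshold, then return the real weight
    if global_step < threshold:
        weight = value
    return weight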


@ -11,15 +11,13 @@ from einops import rearrange, repeat, reduce
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
Intermediates = namedtuple(
'Intermediates', ['pre_softmax_attn', 'post_softmax_attn']
)
LayerIntermediates = namedtuple('Intermediates', [
'hiddens',
'attn_intermediates'
])
LayerIntermediates = namedtuple(
'Intermediates', ['hiddens', 'attn_intermediates']
)
class AbsolutePositionalEmbedding(nn.Module):
@ -39,11 +37,16 @@ class AbsolutePositionalEmbedding(nn.Module):
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
t = (
torch.arange(x.shape[seq_dim], device=x.device).type_as(
self.inv_freq
)
+ offset
)
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
@ -51,6 +54,7 @@ class FixedPositionalEmbedding(nn.Module):
# helpers
def exists(val):
return val is not None
@ -64,18 +68,21 @@ def default(val, d):
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
@ -85,6 +92,7 @@ def max_neg_value(tensor):
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
@ -108,8 +116,15 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(
lambda x: (x[0][len(prefix) :], x[1]),
tuple(kwargs_with_prefix.items()),
)
)
return kwargs_without_prefix, kwargs
@ -139,7 +154,7 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@ -151,7 +166,7 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
@ -173,7 +188,7 @@ class GRUGating(nn.Module):
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
rearrange(residual, 'b n d -> (b n) d'),
)
return gated_output.reshape_as(x)
@ -181,6 +196,7 @@ class GRUGating(nn.Module):
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@ -192,19 +208,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
@ -224,13 +239,15 @@ class Attention(nn.Module):
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False
dropout=0.0,
on_attn=False,
):
super().__init__()
if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
self.scale = dim_head ** -0.5
raise NotImplementedError(
'Check out entmax activation instead of softmax activation!'
)
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
@ -252,7 +269,7 @@ class Attention(nn.Module):
self.sparse_topk = sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
# self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax
# add memory key / values
@ -263,7 +280,11 @@ class Attention(nn.Module):
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(inner_dim, dim)
)
def forward(
self,
@ -274,9 +295,14 @@ class Attention(nn.Module):
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None
mem=None,
):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
b, n, _, h, talking_heads, device = (
*x.shape,
self.heads,
self.talking_heads,
x.device,
)
kv_input = default(context, x)
q_input = x
@ -297,23 +323,35 @@ class Attention(nn.Module):
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)
)
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
q_mask = default(
mask, lambda: torch.ones((b, n), device=device).bool()
)
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
k_mask = default(
k_mask,
lambda: torch.ones((b, k.shape[-2]), device=device).bool(),
)
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = map(
lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v),
)
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True
)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots)
@ -324,7 +362,9 @@ class Attention(nn.Module):
pre_softmax_attn = dots
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
dots = einsum(
'b h i j, h k -> b k i j', dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
@ -336,7 +376,9 @@ class Attention(nn.Module):
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j'
)
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
@ -354,14 +396,16 @@ class Attention(nn.Module):
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
attn = einsum(
'b h i j, h k -> b k i j', attn, self.post_softmax_proj
).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
post_softmax_attn=post_softmax_attn,
)
return self.to_out(out), intermediates
@ -390,7 +434,7 @@ class AttentionLayers(nn.Module):
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs
**kwargs,
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
@ -403,10 +447,14 @@ class AttentionLayers(nn.Module):
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
assert (
rel_pos_num_buckets <= rel_pos_max_distance
), 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None
self.pre_norm = pre_norm
@ -438,15 +486,27 @@ class AttentionLayers(nn.Module):
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
depth_cut = (
par_depth * 2 // 3
) # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
assert (
len(default_block) <= par_width
), 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (
par_width - len(default_block)
)
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
assert (
sandwich_coef > 0 and sandwich_coef <= depth
), 'sandwich coefficient should be less than the depth'
layer_types = (
('a',) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ('f',) * sandwich_coef
)
else:
layer_types = default_block * depth
@ -455,7 +515,9 @@ class AttentionLayers(nn.Module):
for layer_type in self.layer_types:
if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs
)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
@ -472,11 +534,7 @@ class AttentionLayers(nn.Module):
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([
norm_fn(),
layer,
residual_fn
]))
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
def forward(
self,
@ -486,7 +544,7 @@ class AttentionLayers(nn.Module):
context_mask=None,
mems=None,
return_hiddens=False,
**kwargs
**kwargs,
):
hiddens = []
intermediates = []
@ -495,7 +553,9 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1)
if layer_type == 'a':
@ -508,10 +568,22 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
prev_attn=prev_attn, mem=layer_mem)
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == 'f':
out = block(x)
@ -530,8 +602,7 @@ class AttentionLayers(nn.Module):
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates
hiddens=hiddens, attn_intermediates=intermediates
)
return x, intermediates
@ -545,7 +616,6 @@ class Encoder(AttentionLayers):
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(
self,
@ -554,14 +624,16 @@ class TransformerWrapper(nn.Module):
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
max_mem_len=0.0,
emb_dropout=0.0,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
use_pos_emb=True,
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
@ -571,23 +643,34 @@ class TransformerWrapper(nn.Module):
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.pos_emb = (
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.project_emb = (
nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
)
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim)
)
# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'):
@ -605,7 +688,7 @@ class TransformerWrapper(nn.Module):
return_attn=False,
mems=None,
embedding_manager=None,
**kwargs
**kwargs,
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
@ -629,7 +712,9 @@ class TransformerWrapper(nn.Module):
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
@ -638,13 +723,30 @@ class TransformerWrapper(nn.Module):
if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
new_mems = (
list(
map(
lambda pair: torch.cat(pair, dim=-2),
zip(mems, hiddens),
)
)
if exists(mems)
else hiddens
)
new_mems = list(
map(
lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems
)
)
return out, new_mems
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = list(
map(
lambda t: t.post_softmax_attn,
intermediates.attn_intermediates,
)
)
return out, attn_maps
return out
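A minimal usage sketch of the encoder/wrapper pair defined above; the vocabulary size, sequence length and layer sizes are illustrative values, not ones taken from this repository's configs:

import torch

# illustrative hyperparameters only
model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Encoder(dim=512, depth=6, heads=8),
)
tokens = torch.randint(0, 20000, (1, 256))
logits = model(tokens)   # shape (1, 256, 20000)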


@ -111,19 +111,21 @@ class T2I:
strength
embedding_path
The vast majority of these arguments default to reasonable values.
The vast majority of these arguments default to reasonable values.
"""
def __init__(self,
def __init__(
self,
batch_size=1,
iterations = 1,
iterations=1,
steps=50,
seed=None,
cfg_scale=7.5,
weights="models/ldm/stable-diffusion-v1/model.ckpt",
config = "configs/stable-diffusion/v1-inference.yaml",
weights='models/ldm/stable-diffusion-v1/model.ckpt',
config='configs/stable-diffusion/v1-inference.yaml',
width=512,
height=512,
sampler_name="klms",
sampler_name='klms',
latent_channels=4,
downsampling_factor=8,
ddim_eta=0.0, # deterministic
@ -153,7 +155,7 @@ The vast majority of these arguments default to reasonable values.
self.embedding_path = embedding_path
self.model = None # empty for now
self.sampler = None
self.latent_diffusion_weights=latent_diffusion_weights
self.latent_diffusion_weights = latent_diffusion_weights
self.device = device
self.gfpgan = gfpgan
if seed is None:
@ -162,29 +164,34 @@ The vast majority of these arguments default to reasonable values.
self.seed = seed
transformers.logging.set_verbosity_error()
def prompt2png(self,prompt,outdir,**kwargs):
'''
def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
Optional named arguments are the same as those passed to T2I and prompt2image()
'''
results = self.prompt2image(prompt,**kwargs)
pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size))
"""
results = self.prompt2image(prompt, **kwargs)
pngwriter = PngWriter(
outdir, prompt, kwargs.get('batch_size', self.batch_size)
)
for r in results:
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]})'   # gets written into the PNG
pngwriter.write_image(r[0],r[1])
pngwriter.write_image(r[0], r[1])
return pngwriter.files_written
def txt2img(self,prompt,**kwargs):
outdir = kwargs.get('outdir','outputs/img-samples')
return self.prompt2png(prompt,outdir,**kwargs)
def txt2img(self, prompt, **kwargs):
outdir = kwargs.get('outdir', 'outputs/img-samples')
return self.prompt2png(prompt, outdir, **kwargs)
def img2img(self,prompt,**kwargs):
outdir = kwargs.get('outdir','outputs/img-samples')
assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
return self.prompt2png(prompt,outdir,**kwargs)
def img2img(self, prompt, **kwargs):
outdir = kwargs.get('outdir', 'outputs/img-samples')
assert (
'init_img' in kwargs
), 'call to img2img() must include the init_img argument'
return self.prompt2png(prompt, outdir, **kwargs)
def prompt2image(self,
def prompt2image(
self,
# these are common
prompt,
batch_size=None,
@ -203,8 +210,9 @@ The vast majority of these arguments default to reasonable values.
strength=None,
gfpgan_strength=None,
variants=None,
**args): # eat up additional cruft
'''
**args,
): # eat up additional cruft
"""
ldm.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments:
prompt // prompt string (no default)
@ -232,7 +240,7 @@ The vast majority of these arguments default to reasonable values.
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
to create the requested output directory, select a unique informative name for each image, and
write the prompt into the PNG metadata.
'''
"""
steps = steps or self.steps
seed = seed or self.seed
width = width or self.width
@ -243,107 +251,146 @@ The vast majority of these arguments default to reasonable values.
iterations = iterations or self.iterations
strength = strength or self.strength
model = self.load_model() # will instantiate the model or return it from cache
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
w = int(width/64) * 64
h = int(height/64) * 64
model = (
self.load_model()
) # will instantiate the model or return it from cache
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert (
0.0 <= strength <= 1.0
), 'can only work with strength in [0.0, 1.0]'
w = int(width / 64) * 64
h = int(height / 64) * 64
if h != height or w != width:
print(f'Height and width must be multiples of 64. Resizing to {h}x{w}')
print(
f'Height and width must be multiples of 64. Resizing to {h}x{w}'
)
height = h
width = w
scope = autocast if self.precision=="autocast" else nullcontext
scope = autocast if self.precision == 'autocast' else nullcontext
tic = time.time()
results = list()
try:
if init_img:
assert os.path.exists(init_img),f'{init_img}: File not found'
images_iterator = self._img2img(prompt,
assert os.path.exists(init_img), f'{init_img}: File not found'
images_iterator = self._img2img(
prompt,
precision_scope=scope,
batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
steps=steps,
cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
init_img=init_img,strength=strength)
init_img=init_img,
strength=strength,
)
else:
images_iterator = self._txt2img(prompt,
images_iterator = self._txt2img(
prompt,
precision_scope=scope,
batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
steps=steps,
cfg_scale=cfg_scale,
ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
width=width,height=height)
width=width,
height=height,
)
with scope(self.device.type), self.model.ema_scope():
for n in trange(iterations, desc="Sampling"):
for n in trange(iterations, desc='Sampling'):
seed_everything(seed)
iter_images = next(images_iterator)
for image in iter_images:
try:
if gfpgan_strength > 0:
image = self._run_gfpgan(image, gfpgan_strength)
image = self._run_gfpgan(
image, gfpgan_strength
)
except Exception as e:
print(f"Error running GFPGAN - Your image was not enhanced.\n{e}")
print(
f'Error running GFPGAN - Your image was not enhanced.\n{e}'
)
results.append([image, seed])
if image_callback is not None:
image_callback(image,seed)
image_callback(image, seed)
seed = self._new_seed()
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
print(
'Partial results will be returned; if --grid was requested, nothing will be returned.'
)
except RuntimeError as e:
print(str(e))
print('Are you sure your system has an adequate NVIDIA GPU?')
toc = time.time()
print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
print(f'{len(results)} images generated in', '%4.2fs' % (toc - tic))
return results
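As the docstring above says, prompt2image() returns a list of [image, seed] pairs; a minimal usage sketch (the prompt and filenames are illustrative, and the default weights are assumed to be installed):

from ldm.simplet2i import T2I

t2i = T2I()
results = t2i.prompt2image('a watercolor lighthouse at dusk', iterations=2)
for image, seed in results:
    image.save(f'lighthouse-{seed}.png')   # PIL images, one per iteration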
@torch.no_grad()
def _txt2img(self,
def _txt2img(
self,
prompt,
precision_scope,
batch_size,
steps,cfg_scale,ddim_eta,
steps,
cfg_scale,
ddim_eta,
skip_normalize,
width,height):
width,
height,
):
"""
An infinite iterator of images from the prompt.
"""
sampler = self.sampler
while True:
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples, _ = sampler.sample(S=steps,
shape = [
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor,
]
samples, _ = sampler.sample(
S=steps,
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta)
eta=ddim_eta,
)
yield self._samples_to_images(samples)
@torch.no_grad()
def _img2img(self,
def _img2img(
self,
prompt,
precision_scope,
batch_size,
steps,cfg_scale,ddim_eta,
steps,
cfg_scale,
ddim_eta,
skip_normalize,
init_img,strength):
init_img,
strength,
):
"""
An infinite iterator of images from the prompt and the initial image
"""
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
if self.sampler_name != 'ddim':
print(
f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
)
sampler = DDIMSampler(self.model, device=self.device)
else:
sampler = self.sampler
@ -351,9 +398,13 @@ The vast majority of these arguments default to reasonable values.
init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
with precision_scope(self.device.type):
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
t_enc = int(strength * steps)
# print(f"target t_enc is {t_enc} steps")
@ -362,30 +413,43 @@ The vast majority of these arguments default to reasonable values.
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc] * batch_size).to(self.device)
)
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
)
yield self._samples_to_images(samples)
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
uc = self.model.get_learned_conditioning(batch_size * [""])
uc = self.model.get_learned_conditioning(batch_size * [''])
# weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompt)
subprompts, weights = T2I._split_weighted_subprompts(prompt)
if len(subprompts) > 1:
# I don't know if this is correct, but it works
c = torch.zeros_like(uc)
# get total weight for normalizing
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(0,len(subprompts)):
for i in range(0, len(subprompts)):
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight)
c = torch.add(
c,
self.model.get_learned_conditioning(
batch_size * [subprompts[i]]
),
alpha=weight,
)
else: # just standard 1 prompt
c = self.model.get_learned_conditioning(batch_size * [prompt])
return (uc, c)
@ -395,23 +459,29 @@ The vast majority of these arguments default to reasonable values.
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
images = list()
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(
x_sample.cpu().numpy(), 'c h w -> h w c'
)
image = Image.fromarray(x_sample.astype(np.uint8))
images.append(image)
return images
def _new_seed(self):
self.seed = random.randrange(0,np.iinfo(np.uint32).max)
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed
def load_model(self):
""" Load and initialize the model from configuration variables passed at object creation time """
"""Load and initialize the model from configuration variables passed at object creation time"""
if self.model is None:
seed_everything(self.seed)
try:
config = OmegaConf.load(self.config)
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
model = self._load_model_from_config(config,self.weights)
self.device = (
torch.device(self.device)
if torch.cuda.is_available()
else torch.device('cpu')
)
model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(self.embedding_path)
self.model = model.to(self.device)
@ -421,18 +491,26 @@ The vast majority of these arguments default to reasonable values.
raise SystemExit
msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name=='plms':
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
self.sampler = KSampler(
self.model, 'dpm_2', device=self.device
)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device)
self.sampler = KSampler(
self.model, 'euler', device=self.device
)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms':
@ -446,32 +524,38 @@ The vast majority of these arguments default to reasonable values.
return self.model
def _load_model_from_config(self, config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
# if "global_step" in pl_sd:
# print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location='cpu')
# if "global_step" in pl_sd:
# print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd['state_dict']
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.to(self.device)
model.eval()
if self.full_precision:
print('Using slower but more accurate full-precision math (--full_precision)')
print(
'Using slower but more accurate full-precision math (--full_precision)'
)
else:
print('Using half precision math. Call with --full_precision to use slower but more accurate full precision.')
print(
'Using half precision math. Call with --full_precision to use slower but more accurate full precision.'
)
model.half()
return model
def _load_img(self,path):
image = Image.open(path).convert("RGB")
def _load_img(self, path):
image = Image.open(path).convert('RGB')
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
print(f'loaded input image of size ({w}, {h}) from {path}')
w, h = map(
lambda x: x - x % 32, (w, h)
) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.*image - 1.
return 2.0 * image - 1.0
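_load_img() floors each dimension to a multiple of 32 and rescales pixel values from [0, 1] to [-1, 1]; a quick illustration of the rounding (the numbers are made up):

w, h = 513, 767
w, h = map(lambda x: x - x % 32, (w, h))
# w, h -> (512, 736)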
def _split_weighted_subprompts(text):
"""
@ -484,29 +568,31 @@ The vast majority of these arguments default to reasonable values.
prompts = []
weights = []
while remaining > 0:
if ":" in text:
idx = text.index(":") # first occurrence from start
if ':' in text:
idx = text.index(':') # first occurrence from start
# grab up to index as sub-prompt
prompt = text[:idx]
remaining -= idx
# remove from main text
text = text[idx+1:]
text = text[idx + 1 :]
# find value for weight
if " " in text:
idx = text.index(" ") # first occurence
if ' ' in text:
idx = text.index(' ')   # first occurrence
else: # no space, read to end
idx = len(text)
if idx != 0:
try:
weight = float(text[:idx])
except: # couldn't treat as float
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
print(
f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
)
weight = 1.0
else: # no value found
weight = 1.0
# remove from main text
remaining -= idx
text = text[idx+1:]
text = text[idx + 1 :]
# append the sub-prompt and its weight
prompts.append(prompt)
weights.append(weight)
@ -519,13 +605,20 @@ The vast majority of these arguments default to reasonable values.
return prompts, weights
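Based on the parsing loop above (and assuming the unshown no-colon branch treats any trailing text as weight 1.0), a colon-weighted prompt splits like this; the prompt text is made up:

prompts, weights = T2I._split_weighted_subprompts('old lighthouse:2 stormy sea:1')
# prompts -> ['old lighthouse', 'stormy sea']
# weights -> [2.0, 1.0]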
def _run_gfpgan(self, image, strength):
if (self.gfpgan is None):
print(f"GFPGAN not initialized, it must be loaded via the --gfpgan argument")
if self.gfpgan is None:
print(
f'GFPGAN not initialized, it must be loaded via the --gfpgan argument'
)
return image
image = image.convert("RGB")
image = image.convert('RGB')
cropped_faces, restored_faces, restored_img = self.gfpgan.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
cropped_faces, restored_faces, restored_img = self.gfpgan.enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img)
if strength < 1.0:


@ -13,22 +13,25 @@ from queue import Queue
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
lines = '\n'.join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
draw.text((0, 0), lines, fill="black", font=font)
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
print("Can't encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@ -70,22 +73,26 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
)
return total_params
def instantiate_from_config(config, **kwargs):
if not "target" in config:
if not 'target' in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
elif config == '__is_unconditional__':
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(
**config.get('params', dict()), **kwargs
)
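instantiate_from_config() resolves the dotted 'target' path and forwards 'params' as keyword arguments; a small sketch with a stand-in config (the target and values are illustrative):

cfg = {
    'target': 'torch.nn.Linear',
    'params': {'in_features': 16, 'out_features': 4},
}
layer = instantiate_from_config(cfg)
# equivalent to torch.nn.Linear(in_features=16, out_features=4)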
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
module, cls = string.rsplit('.', 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
@ -101,31 +108,36 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")
Q.put('Done')
def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
func: callable,
data,
n_proc,
target_data_type='ndarray',
cpu_intensive=True,
use_worker_id=False,
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
if isinstance(data, np.ndarray) and target_data_type == 'list':
raise ValueError('list expected but function got ndarray.')
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
if target_data_type == 'ndarray':
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
f'The data to be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
)
if cpu_intensive:
@ -135,7 +147,7 @@ def parallel_data_prefetch(
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == "ndarray":
if target_data_type == 'ndarray':
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
@ -149,7 +161,7 @@ def parallel_data_prefetch(
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i: i + step] for i in range(0, len(data), step)]
[data[i : i + step] for i in range(0, len(data), step)]
)
]
processes = []
@ -158,7 +170,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print(f"Start prefetching...")
print(f'Start prefetching...')
import time
start = time.time()
@ -171,13 +183,13 @@ def parallel_data_prefetch(
while k < n_proc:
# get result
res = Q.get()
if res == "Done":
if res == 'Done':
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
print('Exception: ', e)
for p in processes:
p.terminate()
@ -185,7 +197,7 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
print(f'Prefetching complete. [{time.time() - start} sec.]')
if target_data_type == 'ndarray':
if not isinstance(gather_res[0], np.ndarray):

main.py — 659 changes; file diff suppressed because it is too large.


@ -8,41 +8,45 @@ import sys
import copy
import warnings
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter,PromptFormatter
from ldm.dream.pngwriter import PngWriter, PromptFormatter
debugging = False
def main():
''' Initialize command-line parsers and the diffusion model '''
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
if opt.laion400m:
# defaults suitable to the older latent diffusion weights
width = 256
height = 256
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
weights = "models/ldm/text2img-large/model.ckpt"
config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml'
weights = 'models/ldm/text2img-large/model.ckpt'
else:
# some defaults suitable for stable diffusion weights
width = 512
height = 512
config = "configs/stable-diffusion/v1-inference.yaml"
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
config = 'configs/stable-diffusion/v1-inference.yaml'
weights = 'models/ldm/stable-diffusion-v1/model.ckpt'
print("* Initializing, be patient...\n")
print('* Initializing, be patient...\n')
sys.path.append('.')
from pytorch_lightning import logging
from ldm.simplet2i import T2I
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
# creating a simple text2image object with a handful of
# defaults passed on the command line.
# additional parameters will be added (or overridden) during
# the user input loop
t2i = T2I(width=width,
t2i = T2I(
width=width,
height=height,
sampler_name=opt.sampler_name,
weights=weights,
@ -50,7 +54,7 @@ def main():
config=config,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
embedding_path=opt.embedding_path,
device=opt.device
device=opt.device,
)
# make sure the output directory exists
@ -58,12 +62,12 @@ def main():
os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
infile = None
try:
if opt.infile is not None:
infile = open(opt.infile,'r')
infile = open(opt.infile, 'r')
except FileNotFoundError as e:
print(e)
exit(-1)
@ -73,59 +77,74 @@ def main():
# load GFPGAN if requested
if opt.use_gfpgan:
print("\n* --gfpgan was specified, loading gfpgan...")
print('\n* --gfpgan was specified, loading gfpgan...')
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
try:
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
model_path = os.path.join(
opt.gfpgan_dir, opt.gfpgan_model_path
)
if not os.path.isfile(model_path):
raise Exception("GFPGAN model not found at path "+model_path)
raise Exception(
'GFPGAN model not found at path ' + model_path
)
sys.path.append(os.path.abspath(opt.gfpgan_dir))
from gfpgan import GFPGANer
bg_upsampler = load_gfpgan_bg_upsampler(opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile)
bg_upsampler = load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile
)
t2i.gfpgan = GFPGANer(model_path=model_path, upscale=opt.gfpgan_upscale, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler)
t2i.gfpgan = GFPGANer(
model_path=model_path,
upscale=opt.gfpgan_upscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler,
)
except Exception:
import traceback
print("Error loading GFPGAN:", file=sys.stderr)
print('Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print("\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...")
print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)..."
)
log_path = os.path.join(opt.outdir,'dream_log.txt')
with open(log_path,'a') as log:
log_path = os.path.join(opt.outdir, 'dream_log.txt')
with open(log_path, 'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,opt.outdir,cmd_parser,log,infile)
main_loop(t2i, opt.outdir, cmd_parser, log, infile)
log.close()
if infile:
infile.close()
def main_loop(t2i,outdir,parser,log,infile):
''' prompt/read/execute loop '''
def main_loop(t2i, outdir, parser, log, infile):
"""prompt/read/execute loop"""
done = False
last_seeds = []
while not done:
try:
command = infile.readline() if infile else input("dream> ")
command = infile.readline() if infile else input('dream> ')
except EOFError:
done = True
break
if infile and len(command)==0:
if infile and len(command) == 0:
done = True
break
if command.startswith(('#','//')):
if command.startswith(('#', '//')):
continue
# before splitting, escape single quotes so as not to mess
# up the parser
command = command.replace("'","\\'")
command = command.replace("'", "\\'")
try:
elements = shlex.split(command)
@ -133,26 +152,28 @@ def main_loop(t2i,outdir,parser,log,infile):
print(str(e))
continue
if len(elements)==0:
if len(elements) == 0:
continue
if elements[0]=='q':
if elements[0] == 'q':
done = True
break
if elements[0]=='cd' and len(elements)>1:
if elements[0] == 'cd' and len(elements) > 1:
if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}")
outdir=elements[1]
print(f'setting image output directory to {elements[1]}')
outdir = elements[1]
else:
print(f"directory {elements[1]} does not exist")
print(f'directory {elements[1]} does not exist')
continue
if elements[0]=='pwd':
print(f"current output directory is {outdir}")
if elements[0] == 'pwd':
print(f'current output directory is {outdir}')
continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
if elements[0].startswith(
'!dream'
): # in case a stored prompt still contains the !dream command
elements.pop(0)
# rearrange the arguments to mimic how it works in the Dream bot.
@ -160,48 +181,52 @@ def main_loop(t2i,outdir,parser,log,infile):
switches_started = False
for el in elements:
if el[0]=='-' and not switches_started:
if el[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][:len(switches[0])-1]
switches[0] = switches[0][: len(switches[0]) - 1]
try:
opt = parser.parse_args(switches)
except SystemExit:
parser.print_help()
continue
if len(opt.prompt)==0:
print("Try again with a prompt!")
if len(opt.prompt) == 0:
print('Try again with a prompt!')
continue
if opt.seed is not None and opt.seed<0: # retrieve previous value!
if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try:
opt.seed = last_seeds[opt.seed]
print(f"reusing previous seed {opt.seed}")
print(f'reusing previous seed {opt.seed}')
except IndexError:
print(f"No previous seed at position {opt.seed} found")
print(f'No previous seed at position {opt.seed} found')
opt.seed = None
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
individual_images = not opt.grid
try:
file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size)
file_writer = PngWriter(outdir, normalized_prompt, opt.batch_size)
callback = file_writer.write_image if individual_images else None
image_list = t2i.prompt2image(image_callback=callback,**vars(opt))
results = file_writer.files_written if individual_images else image_list
image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
results = (
file_writer.files_written if individual_images else image_list
)
if opt.grid and len(results) > 0:
grid_img = file_writer.make_grid([r[0] for r in results])
filename = file_writer.unique_filename(results[0][1])
seeds = [a[1] for a in results]
results = [[filename,seeds]]
results = [[filename, seeds]]
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
file_writer.save_image_and_prompt_to_png(grid_img,metadata_prompt,filename)
file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
)
last_seeds = [r[1] for r in results]
@ -213,10 +238,11 @@ def main_loop(t2i,outdir,parser,log,infile):
print(e)
continue
print("Outputs:")
write_log_message(t2i,normalized_prompt,results,log)
print('Outputs:')
write_log_message(t2i, normalized_prompt, results, log)
print('goodbye!')
print("goodbye!")
def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
import torch
@ -224,13 +250,24 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
import warnings
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.')
warnings.warn(
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.'
)
bg_upsampler = None
else:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
@ -238,12 +275,14 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
tile=bg_tile,
tile_pad=10,
pre_pad=0,
half=True) # need to set False in CPU mode
half=True,
) # need to set False in CPU mode
else:
bg_upsampler = None
return bg_upsampler
# variant generation is going to be superseded by a generalized
# "prompt-morph" functionality
# def generate_variants(t2i,outdir,opt,previous_gens):
@ -270,108 +309,207 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
# return variants
def write_log_message(t2i,prompt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata'''
def write_log_message(t2i, prompt, results, logfile):
"""logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata"""
last_seed = None
img_num = 1
seenit = {}
for r in results:
seed = r[1]
log_message = (f'{r[0]}: {prompt} -S{seed}')
log_message = f'{r[0]}: {prompt} -S{seed}'
print(log_message)
logfile.write(log_message+"\n")
logfile.write(log_message + '\n')
logfile.flush()
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
"--latent_diffusion",
"-l",
parser = argparse.ArgumentParser(
description="Parse script's command line args"
)
parser.add_argument(
'--laion400m',
'--latent_diffusion',
'-l',
dest='laion400m',
action='store_true',
help="fallback to the latent diffusion (laion400m) weights and config")
parser.add_argument("--from_file",
help='fallback to the latent diffusion (laion400m) weights and config',
)
parser.add_argument(
'--from_file',
dest='infile',
type=str,
help="if specified, load prompts from this file")
parser.add_argument('-n','--iterations',
help='if specified, load prompts from this file',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help="number of images to generate")
parser.add_argument('-F','--full_precision',
help='number of images to generate',
)
parser.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help="use slower full precision math for calculations")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
help='use slower full precision math for calculations',
)
parser.add_argument(
'--sampler',
'-m',
dest='sampler_name',
choices=[
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
],
default='k_lms',
help="which sampler to use (k_lms) - can only be set on command line")
parser.add_argument('--outdir',
help='which sampler to use (k_lms) - can only be set on command line',
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default="outputs/img-samples",
help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples")
parser.add_argument('--embedding_path',
default='outputs/img-samples',
help='directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)',
)
parser.add_argument(
'--embedding_path',
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
parser.add_argument('--device',
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
)
parser.add_argument(
'--device',
'-d',
type=str,
default="cuda",
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if avalible")
default='cuda',
help='device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available',
)
# GFPGAN related args
parser.add_argument('--gfpgan',
parser.add_argument(
'--gfpgan',
dest='use_gfpgan',
action='store_true',
help="load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory")
parser.add_argument("--gfpgan_upscale",
help='load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory',
)
parser.add_argument(
'--gfpgan_upscale',
type=int,
default=2,
help="The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified")
parser.add_argument("--gfpgan_bg_upsampler",
help='The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_bg_upsampler',
type=str,
default='realesrgan',
help="Background upsampler. Default: None. Options: realesrgan, none. Only used if --gfpgan is specified")
parser.add_argument("--gfpgan_bg_tile",
help='Background upsampler. Default: realesrgan. Options: realesrgan, none. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_bg_tile',
type=int,
default=400,
help="Tile size for background sampler, 0 for no tile during testing. Default: 400. Only used if --gfpgan is specified")
parser.add_argument("--gfpgan_model_path",
help='Tile size for background sampler, 0 for no tile during testing. Default: 400. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_model_path',
type=str,
default='experiments/pretrained_models/GFPGANv1.3.pth',
help="indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified")
parser.add_argument("--gfpgan_dir",
help='indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_dir',
type=str,
default='../GFPGAN',
help="indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified")
help='indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified',
)
return parser
def create_cmd_parser():
parser = argparse.ArgumentParser(description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12')
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
)
parser.add_argument('prompt')
parser.add_argument('-s','--steps',type=int,help="number of steps")
parser.add_argument('-S','--seed',type=int,help="image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform (slower, but will provide seeds for individual images)")
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (will not provide seeds for individual images!)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)")
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
parser.add_argument('-G','--gfpgan_strength', default=0.5, type=float, help="The strength at which to apply the GFPGAN model to the result, in order to improve faces.")
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
parser.add_argument('-s', '--steps', type=int, help='number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-b',
'--batch_size',
type=int,
default=1,
help='number of images to produce per sampling (will not provide seeds for individual images!)',
)
parser.add_argument(
'-W', '--width', type=int, help='image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='prompt configuration scale',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0.5,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='skip subprompt weight normalization',
)
return parser
if __name__ == "__main__":
if __name__ == '__main__':
main()
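Putting the two parsers above together, a session looks roughly like the sketch below; the prompts and the init-image path are illustrative, and only flags defined above are used:

* Initializing, be patient...
dream> a fantastic alien landscape -W1024 -H960 -s100 -n12
dream> portrait of a sea captain -I outputs/img-samples/000012.png -f 0.6 -G 0.5
dream> q
goodbye!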


@ -11,26 +11,28 @@ import warnings
transformers.logging.set_verbosity_error()
# this will preload the Bert tokenizer files
print("preloading bert tokenizer...")
print('preloading bert tokenizer...')
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
print("...success")
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
# this will download requirements for Kornia
print("preloading Kornia requirements (ignore the deprecation warnings)...")
print('preloading Kornia requirements (ignore the deprecation warnings)...')
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia
print("...success")
print('...success')
version='openai/clip-vit-large-patch14'
version = 'openai/clip-vit-large-patch14'
print('preloading CLIP model (Ignore the deprecation warnings)...')
sys.stdout.flush()
import clip
from transformers import CLIPTokenizer, CLIPTextModel
tokenizer =CLIPTokenizer.from_pretrained(version)
transformer=CLIPTextModel.from_pretrained(version)
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('\n\n...success')
# In the event that the user has installed GFPGAN and also elected to use
@ -38,23 +40,33 @@ print('\n\n...success')
gfpgan = False
try:
from realesrgan import RealESRGANer
gfpgan = True
except ModuleNotFoundError:
pass
if gfpgan:
print("Loading models from RealESRGAN and facexlib")
print('Loading models from RealESRGAN and facexlib')
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(scale=2,
RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2))
FaceRestoreHelper(1,det_model='retinaface_resnet50')
print("...success")
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
),
)
FaceRestoreHelper(1, det_model='retinaface_resnet50')
print('...success')
except Exception:
import traceback
print("Error loading GFPGAN:")
print('Error loading GFPGAN:')
print(traceback.format_exc())