Merge branch 'main' into half-precision-embeddings

This commit is contained in:
Paul Sajna 2022-08-26 08:33:46 -07:00 committed by GitHub
commit 555f13e469
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 7153 additions and 3446 deletions

3
.gitignore vendored
View File

@ -2,6 +2,9 @@
outputs/
models/ldm/stable-diffusion-v1/model.ckpt
# ignore a directory which serves as a place for initial images
inputs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

13
.gitmodules vendored
View File

@ -1,13 +0,0 @@
[submodule "taming-transformers"]
path = src/taming-transformers
url = https://github.com/CompVis/taming-transformers.git
ignore = dirty
[submodule "clip"]
path = src/clip
url = https://github.com/openai/CLIP.git
ignore = dirty
[submodule "k-diffusion"]
path = src/k-diffusion
url = https://github.com/lstein/k-diffusion.git
ignore = dirty

174
README.md
View File

@ -1,9 +1,7 @@
# Stable Diffusion Dream Script
This is a fork of CompVis/stable-diffusion, the wonderful open source
text-to-image generator.
The original has been modified in several ways:
text-to-image generator. The original has been improved in several ways:
## Interactive command-line interface similar to the Discord bot
@ -31,13 +29,7 @@ runs from the command-line (CMD or Terminal window), and does not have a GUI.
(ldm) ~/stable-diffusion$ python3 ./scripts/dream.py
* Initializing, be patient...
Loading model from models/ldm/text2img-large/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 872.30 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loading Bert tokenizer from "models/bert"
setting sampler to plms
(...more initialization messages...)
* Initialization done! Awaiting your command...
dream> ashley judd riding a camel -n2 -s150
@ -88,6 +80,142 @@ You may also pass a -v<count> option to generate count variants on the original
passing the first generated image back into img2img the requested number of times. It generates interesting
variants.
## GFPGAN Support
This script also provides the ability to invoke GFPGAN after image
generation. Doing so will enhance faces and optionally upscale the
image to a higher resolution.
To use the ability, clone the [GFPGAN
repository](https://github.com/TencentARC/GFPGAN) and follow their
installation instructions. By default, we expect GFPGAN to be
installed in a 'GFPGAN' sibling directory. Be sure that the "ldm"
conda environment is active as you install GFPGAN.
You may also want to install Real-ESRGAN, if you want to enhance
non-face regions in the image, by installing the pip Real-ESRGAN
package.
```
pip install realesrgan
```
Users whose GPU machines are isolated from the Internet (e.g. on a
University cluster) should be aware that the first time you run
dream.py with GFPGAN turned on, it will try to download model files
from the Internet. To rectify this, you may run `python3
scripts/preload_models.pl` after you have installed GFPGAN and all its
dependencies.
Now, you can run this script by adding the **--gfpgan** option. Any
issues with GFPGAN will be reported on initialization.
~~~~
(ldm) ~/stable-diffusion$ python3 ./scripts/dream.py --gfpgan
* Initializing, be patient...
(...more initialization messages...)
* --gfpgan was specified, loading gfpgan...
(...even more initialization messages...)
* Initialization done! Awaiting your command...
~~~~
When generating prompts, add a -G or --gfpgan_strenth option to
control the strength of the GFPGAN enhancement. 0.0 is no
enhancement, 1.0 is maximum enhancement.
So for instance, to apply the maximum strength:
~~~~
dream> a man wearing a pineapple hat -G 1
~~~~
This also works with img2img:
~~~
dream> a man wearing a pineapple hat -I path/to/your/file.png -G 1
~~~
That's it!
There's also a bunch of options to control GFPGAN settings when
starting the script for different configs that you can read about in
the help text. This will let you control where GFPGAN is installed, if
upsampling is enabled, the upsampler to use and the model path.
By default, images will be upscaled by 2-fold, meaning that the old
Stable Diffusion default size of 512x512 will now be a glorious
detailed 1024x1024. The extent of upscaling is set when you run the
script, and can't be changed while it's running. However, at any time
you may specify **-G0** to turn off upscaling and facial enhancement
for that image or set of images.
Note that loading GFPGAN consumes additional GPU memory, and will add
a few seconds to image generation. However, if can afford a 3090s with
24Gi, the results are well worth it.
## Barebones Web Server
As of version 1.10, this distribution comes with a bare bones web
server (see screenshot). To use it, run the command:
~~~~
(ldm) ~/stable-diffusion$ python3 scripts/dream_web.py
~~~~
You can then connect to the server by pointing your web browser at
http://localhost:9090, or to the network name or IP address of the server.
Kudos to [Tesseract Cat](https://github.com/TesseractCat) for
contributing this code.
![Dream Web Server](static/dream_web_server.png)
## Reading Prompts from a File
You can automate dream.py by providing a text file with the prompts
you want to run, one line per prompt. The text file must be composed
with a text editor (e.g. Notepad) and not a word processor. Each line
should look like what you would type at the dream> prompt:
~~~~
a beautiful sunny day in the park, children playing -n4 -C10
stormy weather on a mountain top, goats grazing -s100
innovative packaging for a squid's dinner -S137038382
~~~~
Then pass this file's name to dream.py when you invoke it:
~~~~
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file="path/to/prompts.txt"
~~~~
## Shortcut for reusing seeds from the previous command
Since it is so common to reuse seeds while refining a prompt, there is
now a shortcut as of version 1.11. Provide a **-S** (or **--seed**)
switch of -1 to use the seed of the most recent image generated. If
you produced multiple images with the **-n** switch, then you can go
back further using -2, -3, etc. up to the first image generated by the
previous command. Sorry, but you can't go back further than one
command.
Here's an example of using this to do a quick refinement. It also
illustrates using the new **-G** switch to turn on upscaling and
face enhancement (see previous section):
~~~~
dream> a cute child playing hopscotch -G0.5
[...]
outputs/img-samples/000039.3498014304.png: "a cute child playing hopscotch" -s50 -b1 -W512 -H512 -C7.5 -mk_lms -S3498014304
# I wonder what it will look like if I bump up the steps and set facial enhancement to full strength?
dream> a cute child playing hopscotch -G1.0 -s100 -S -1
reusing previous seed 3498014304
[...]
outputs/img-samples/000040.3498014304.png: "a cute child playing hopscotch" -G1.0 -s100 -b1 -W512 -H512 -C7.5 -mk_lms -S3498014304
~~~~
## Weighted Prompts
You may weight different sections of the prompt to tell the sampler to attach different levels of
@ -170,10 +298,24 @@ repository and associated paper for details and limitations.
## Changes
* v1.11 (26 August 2022)
* NEW FEATURE: Support upscaling and face enhancement using the GFPGAN module. (kudos to [Oceanswave](https://github.com/Oceanswave)
* You now can specify a seed of -1 to use the previous image's seed, -2 to use the seed for the image generated before that, etc.
Seed memory only extends back to the previous command, but will work on all images generated with the -n# switch.
* Variant generation support temporarily disabled pending more general solution.
* Created a feature branch named **yunsaki-morphing-dream** which adds experimental support for
iteratively modifying the prompt and its parameters. Please see[ Pull Request #86](https://github.com/lstein/stable-diffusion/pull/86)
for a synopsis of how this works. Note that when this feature is eventually added to the main branch, it will may be modified
significantly.
* v1.10 (25 August 2022)
* A barebones but fully functional interactive web server for online generation of txt2img and img2img.
* v1.09 (24 August 2022)
* A new -v option allows you to generate multiple variants of an initial image
in img2img mode. (kudos to Oceanswave)
* Added ability to personalize text to image generation (kudos to nicolai256)
in img2img mode. (kudos to [Oceanswave](https://github.com/Oceanswave). [
See this discussion in the PR for examples and details on use](https://github.com/lstein/stable-diffusion/pull/71#issuecomment-1226700810))
* Added ability to personalize text to image generation (kudos to [Oceanswave](https://github.com/Oceanswave) and [nicolai256](https://github.com/nicolai256))
* Enabled all of the samplers from k_diffusion
* v1.08 (24 August 2022)
@ -457,9 +599,11 @@ to send me an email if you use and like the script.
*Contributions by:*
[Peter Kowalczyk](https://github.com/slix), [Henry Harrison](https://github.com/hwharrison),
[xraxra](https://github.com/xraxra), [bmaltais](https://github.com/bmaltais), [Sean McLellan] (https://github.com/Oceanswave],
[nicolai256](https://github.com/nicolai256], [Benjamin Warner](https://github.com/warner-benjamin),
and [tildebyte](https://github.com/tildebyte)
[xraxra](https://github.com/xraxra), [bmaltais](https://github.com/bmaltais), [Sean McLellan](https://github.com/Oceanswave),
[nicolai256](https://github.com/nicolai256), [Benjamin Warner](https://github.com/warner-benjamin),
[tildebyte](https://github.com/tildebyte),
and [Tesseract Cat](https://github.com/TesseractCat)
Original portions of the software are Copyright (c) 2020 Lincoln D. Stein (https://github.com/lstein)

View File

@ -2,6 +2,7 @@ Feature requests:
1. "gobig" mode - split image into strips, scale up, add detail using
img2img and reassemble with feathering. Issue #66.
See https://github.com/jquesnelle/txt2imghd
2. Port basujindal low VRAM optimizations. Issue #62

View File

@ -10,7 +10,6 @@ dependencies:
- torchvision=0.12.0
- numpy=1.19.2
- pip:
- accelerate==0.12.0
- albumentations==0.4.3
- opencv-python==4.1.2.30
- pudb==2019.2

View File

@ -1,11 +1,17 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
from torch.utils.data import (
Dataset,
ConcatDataset,
ChainDataset,
IterableDataset,
)
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
"""
Define an interface to make the IterableDatasets for text2img data chainable
'''
"""
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
@ -13,11 +19,13 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
print(
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
pass

View File

@ -11,24 +11,34 @@ from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import (
str_to_indices,
give_synsets_from_indices,
download,
retrieve,
)
from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from ldm.modules.image_degradation import (
degradation_fn_bsr,
degradation_fn_bsr_light,
)
def synset2idx(path_to_yaml="data/index_synset.yaml"):
def synset2idx(path_to_yaml='data/index_synset.yaml'):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v,k) for k,v in di2s.items())
return dict((v, k) for k, v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not type(self.config)==dict:
if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.keep_orig_class_label = self.config.get(
'keep_orig_class_label', False
)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare()
self._prepare_synset_to_human()
@ -46,17 +56,23 @@ class ImageNetBase(Dataset):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
ignore = set(
[
'n06596364_9591.JPEG',
]
)
relpaths = [
rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore
]
if 'sub_indices' in self.config:
indices = str_to_indices(self.config['sub_indices'])
synsets = give_synsets_from_indices(
indices, path_to_yaml=self.idx2syn
) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
syn = rpath.split("/")[0]
syn = rpath.split('/')[0]
if syn in synsets:
files.append(rpath)
return files
@ -65,78 +81,89 @@ class ImageNetBase(Dataset):
def _prepare_synset_to_human(self):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
self.human_dict = os.path.join(self.root, 'synset_human.txt')
if (
not os.path.exists(self.human_dict)
or not os.path.getsize(self.human_dict) == SIZE
):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
if not os.path.exists(self.idx2syn):
download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self):
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
if (not os.path.exists(self.human2integer)):
URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
self.human2integer = os.path.join(
self.root, 'imagenet1000_clsidx_to_labels.txt'
)
if not os.path.exists(self.human2integer):
download(URL, self.human2integer)
with open(self.human2integer, "r") as f:
with open(self.human2integer, 'r') as f:
lines = f.read().splitlines()
assert len(lines) == 1000
self.human2integer_dict = dict()
for line in lines:
value, key = line.split(":")
value, key = line.split(':')
self.human2integer_dict[key] = int(value)
def _load(self):
with open(self.txt_filelist, "r") as f:
with open(self.txt_filelist, 'r') as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
print(
'Removed {} files from filelist during filtering.'.format(
l1 - len(self.relpaths)
)
)
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.synsets = [p.split('/')[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
class_dict = dict(
(synset, i) for i, synset in enumerate(unique_synsets)
)
if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets]
else:
self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, "r") as f:
with open(self.human_dict, 'r') as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets]
labels = {
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
'relpath': np.array(self.relpaths),
'synsets': np.array(self.synsets),
'class_label': np.array(self.class_labels),
'human_label': np.array(self.human_labels),
}
if self.process_images:
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
self.size = retrieve(self.config, 'size', default=256)
self.data = ImagePaths(
self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
else:
self.data = self.abspaths
class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
NAME = 'ILSVRC2012_train'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
FILES = [
"ILSVRC2012_img_train.tar",
'ILSVRC2012_img_train.tar',
]
SIZES = [
147897477120,
@ -151,57 +178,64 @@ class ImageNetTrain(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
self.random_crop = retrieve(
self.config, 'ImageNetTrain/random_crop', default=True
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir)
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
print('Extracting sub-tars.')
subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
subdir = subpath[: -len('.tar')]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
with tarfile.open(subpath, 'r:') as tar:
tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
NAME = 'ILSVRC2012_validation'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
FILES = [
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
'ILSVRC2012_img_val.tar',
'validation_synset.txt',
]
SIZES = [
6744924160,
@ -217,39 +251,49 @@ class ImageNetValidation(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
self.random_crop = retrieve(
self.config, 'ImageNetValidation/random_crop', default=False
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
if (
not os.path.exists(vspath)
or not os.path.getsize(vspath) == self.SIZES[1]
):
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
with open(vspath, 'r') as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders")
print('Reorganizing into synset folders')
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
@ -258,21 +302,26 @@ class ImageNetValidation(ImageNetBase):
dst = os.path.join(datadir, v)
shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
random_crop=True):
def __init__(
self,
size=None,
degradation=None,
downscale_f=4,
min_crop_f=0.5,
max_crop_f=1.0,
random_crop=True,
):
"""
Imagenet Superresolution Dataloader
Performs following ops in order:
@ -296,67 +345,86 @@ class ImageNetSR(Dataset):
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.)
assert max_crop_f <= 1.0
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
self.image_rescaler = albumentations.SmallestMaxSize(
max_size=size, interpolation=cv2.INTER_AREA
)
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
self.pil_interpolation = (
False # gets reset later if incase interp_op is from pillow
)
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
if degradation == 'bsrgan':
self.degradation_process = partial(
degradation_fn_bsr, sf=downscale_f
)
elif degradation == "bsrgan_light":
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
elif degradation == 'bsrgan_light':
self.degradation_process = partial(
degradation_fn_bsr_light, sf=downscale_f
)
else:
interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
'cv_nearest': cv2.INTER_NEAREST,
'cv_bilinear': cv2.INTER_LINEAR,
'cv_bicubic': cv2.INTER_CUBIC,
'cv_area': cv2.INTER_AREA,
'cv_lanczos': cv2.INTER_LANCZOS4,
'pil_nearest': PIL.Image.NEAREST,
'pil_bilinear': PIL.Image.BILINEAR,
'pil_bicubic': PIL.Image.BICUBIC,
'pil_box': PIL.Image.BOX,
'pil_hamming': PIL.Image.HAMMING,
'pil_lanczos': PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
self.pil_interpolation = degradation.startswith('pil_')
if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
self.degradation_process = partial(
TF.resize,
size=self.LR_size,
interpolation=interpolation_fn,
)
else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
interpolation=interpolation_fn)
self.degradation_process = albumentations.SmallestMaxSize(
max_size=self.LR_size, interpolation=interpolation_fn
)
def __len__(self):
return len(self.base)
def __getitem__(self, i):
example = self.base[i]
image = Image.open(example["file_path_"])
image = Image.open(example['file_path_'])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
crop_side_len = min_side_len * np.random.uniform(
self.min_crop_f, self.max_crop_f, size=None
)
crop_side_len = int(crop_side_len)
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.CenterCrop(
height=crop_side_len, width=crop_side_len
)
else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.RandomCrop(
height=crop_side_len, width=crop_side_len
)
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]
image = self.cropper(image=image)['image']
image = self.image_rescaler(image=image)['image']
if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image)
@ -364,10 +432,10 @@ class ImageNetSR(Dataset):
LR_image = np.array(LR_image).astype(np.uint8)
else:
LR_image = self.degradation_process(image=image)["image"]
LR_image = self.degradation_process(image=image)['image']
example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)
return example
@ -377,9 +445,11 @@ class ImageNetSRTrain(ImageNetSR):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
with open('data/imagenet_train_hr_indices.p', 'rb') as f:
indices = pickle.load(f)
dset = ImageNetTrain(process_images=False,)
dset = ImageNetTrain(
process_images=False,
)
return Subset(dset, indices)
@ -388,7 +458,9 @@ class ImageNetSRValidation(ImageNetSR):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
with open('data/imagenet_val_hr_indices.p', 'rb') as f:
indices = pickle.load(f)
dset = ImageNetValidation(process_images=False,)
dset = ImageNetValidation(
process_images=False,
)
return Subset(dset, indices)

View File

@ -7,30 +7,33 @@ from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
):
def __init__(
self,
txt_file,
data_root,
size=None,
interpolation='bicubic',
flip_p=0.5,
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
with open(self.data_paths, 'r') as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
'relative_file_path_': [l for l in self.image_paths],
'file_path_': [
os.path.join(self.data_root, l) for l in self.image_paths
],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
@ -38,55 +41,86 @@ class LSUNBase(Dataset):
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = Image.open(example['file_path_'])
if not image.mode == 'RGB':
image = image.convert('RGB')
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
super().__init__(
txt_file='data/lsun/church_outdoor_train.txt',
data_root='data/lsun/churches',
**kwargs
)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/church_outdoor_val.txt',
data_root='data/lsun/churches',
flip_p=flip_p,
**kwargs
)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
super().__init__(
txt_file='data/lsun/bedrooms_train.txt',
data_root='data/lsun/bedrooms',
**kwargs
)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
super().__init__(
txt_file='data/lsun/bedrooms_val.txt',
data_root='data/lsun/bedrooms',
flip_p=flip_p,
**kwargs
)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
super().__init__(
txt_file='data/lsun/cat_train.txt',
data_root='data/lsun/cats',
**kwargs
)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/cat_val.txt',
data_root='data/lsun/cats',
flip_p=flip_p,
**kwargs
)

View File

@ -72,31 +72,57 @@ imagenet_dual_templates_small = [
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation='bicubic',
flip_p=0.5,
set='train',
placeholder_token='*',
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self._length = self.num_images
self.placeholder_token = placeholder_token
@ -107,17 +133,20 @@ class PersonalizedBase(Dataset):
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
if set == 'train':
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
@ -127,34 +156,47 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
placeholder_string = (
f'{self.coarse_class_text} {placeholder_string}'
)
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images])
text = random.choice(imagenet_dual_templates_small).format(
placeholder_string, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(placeholder_string)
example["caption"] = text
text = random.choice(imagenet_templates_small).format(
placeholder_string
)
example['caption'] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example

View File

@ -50,29 +50,55 @@ imagenet_dual_templates_small = [
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation='bicubic',
flip_p=0.5,
set='train',
placeholder_token='*',
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self._length = self.num_images
self.placeholder_token = placeholder_token
@ -80,17 +106,20 @@ class PersonalizedBase(Dataset):
self.center_crop = center_crop
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
if set == 'train':
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
@ -100,30 +129,41 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
text = random.choice(imagenet_dual_templates_small).format(
self.placeholder_token, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(self.placeholder_token)
example["caption"] = text
text = random.choice(imagenet_templates_small).format(
self.placeholder_token
)
example['caption'] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example

125
ldm/dream/pngwriter.py Normal file
View File

@ -0,0 +1,125 @@
"""
Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds
appropriate names for them, and writes prompt metadata
into the PNG. Intended to be subclassable in order to
create more complex naming schemes, including using the
prompt for file/directory names.
PromptFormatter -- Utility for converting a Namespace of prompt parameters
back into a formatted prompt string with command-line switches.
"""
import os
import re
from math import sqrt, floor, ceil
from PIL import Image, PngImagePlugin
# -------------------image generation utils-----
class PngWriter:
def __init__(self, outdir, prompt=None, batch_size=1):
self.outdir = outdir
self.batch_size = batch_size
self.prompt = prompt
self.filepath = None
self.files_written = []
os.makedirs(outdir, exist_ok=True)
def write_image(self, image, seed):
self.filepath = self.unique_filename(
seed, self.filepath
) # will increment name in some sensible way
try:
prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image, prompt, self.filepath)
except IOError as e:
print(e)
self.files_written.append([self.filepath, seed])
def unique_filename(self, seed, previouspath=None):
revision = 1
if previouspath is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next(
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
'0000000.0.png',
)
basecount = int(filename.split('.', 1)[0])
basecount += 1
if self.batch_size > 1:
filename = f'{basecount:06}.{seed}.01.png'
else:
filename = f'{basecount:06}.{seed}.png'
return os.path.join(self.outdir, filename)
else:
basename = os.path.basename(previouspath)
x = re.match('^(\d+)\..*\.png', basename)
if not x:
return self.unique_filename(seed, previouspath)
basecount = int(x.groups()[0])
series = 0
finished = False
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if self.batch_size > 1 or os.path.exists(
os.path.join(self.outdir, filename)
):
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(
os.path.join(self.outdir, filename)
)
return os.path.join(self.outdir, filename)
def save_image_and_prompt_to_png(self, image, prompt, path):
info = PngImagePlugin.PngInfo()
info.add_text('Dream', prompt)
image.save(path, 'PNG', pnginfo=info)
def make_grid(self, image_list, rows=None, cols=None):
image_cnt = len(image_list)
if None in (rows, cols):
rows = floor(sqrt(image_cnt)) # try to make it square
cols = ceil(image_cnt / rows)
width = image_list[0].width
height = image_list[0].height
grid_img = Image.new('RGB', (width * cols, height * rows))
for r in range(0, rows):
for c in range(0, cols):
i = r * rows + c
grid_img.paste(image_list[i], (c * width, r * height))
return grid_img
class PromptFormatter:
def __init__(self, t2i, opt):
self.t2i = t2i
self.opt = opt
def normalize_prompt(self):
"""Normalize the prompt and switches"""
t2i = self.t2i
opt = self.opt
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if opt.gfpgan_strength:
switches.append(f'-G{opt.gfpgan_strength}')
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)

127
ldm/dream/readline.py Normal file
View File

@ -0,0 +1,127 @@
"""
Readline helper functions for dream.py (linux and mac only).
"""
import os
import re
import atexit
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except:
readline_available = False
class Completer:
def __init__(self, options):
self.options = sorted(options)
return
def complete(self, text, state):
buffer = readline.get_line_buffer()
if text.startswith(('-I', '--init_img')):
return self._path_completions(text, state, ('.png'))
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
return self._path_completions(text, state, ())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [
s for s in self.options if s and s.startswith(text)
]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _path_completions(self, text, state, extensions):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I', '', 1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=', '', 1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path) == 0:
matches.append(text + './')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n) > 1:
continue
full_path = os.path.join(dir, n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(
os.path.join(os.path.dirname(text), n) + '/'
)
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text), n))
try:
response = matches[state]
except IndexError:
response = None
return response
if readline_available:
readline.set_completer(
Completer(
[
'cd',
'pwd',
'--steps',
'-s',
'--seed',
'-S',
'--iterations',
'-n',
'--batch_size',
'-b',
'--width',
'-W',
'--height',
'-H',
'--cfg_scale',
'-C',
'--grid',
'-g',
'--individual',
'-i',
'--init_img',
'-I',
'--strength',
'-f',
'-v',
'--variants',
]
).complete
)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file, histfile)

View File

@ -5,32 +5,49 @@ class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_lr}'
)
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
@ -38,15 +55,30 @@ class LambdaWarmUpCosineScheduler2:
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
def __init__(
self,
warm_up_steps,
f_min,
f_max,
f_start,
cycle_lengths,
verbosity_interval=0,
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
@ -60,17 +92,25 @@ class LambdaWarmUpCosineScheduler2:
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}'
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (
self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
f = self.f_min[cycle] + 0.5 * (
self.f_max[cycle] - self.f_min[cycle]
) * (1 + np.cos(t * np.pi))
self.last_f = f
return f
@ -79,20 +119,25 @@ class LambdaWarmUpCosineScheduler2:
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}'
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (
self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f

View File

@ -6,29 +6,32 @@ from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.distributions.distributions import (
DiagonalGaussianDistribution,
)
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False,
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
@ -36,24 +39,34 @@ class VQModel(pl.LightningModule):
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@ -66,28 +79,30 @@ class VQModel(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
print(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
print(f'{context}: Restored training weights')
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
print(f'Missing Keys: {missing}')
print(f'Unexpected Keys: {unexpected}')
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
@ -115,7 +130,7 @@ class VQModel(pl.LightningModule):
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
@ -125,7 +140,11 @@ class VQModel(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
@ -133,9 +152,11 @@ class VQModel(pl.LightningModule):
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = F.interpolate(x, size=new_resize, mode='bicubic')
x = x.detach()
return x
@ -147,81 +168,139 @@ class VQModel(pl.LightningModule):
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
predicted_indices=ind,
)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
log_dict_ema = self._validation_step(
batch, batch_idx, suffix='_ema'
)
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
def _validation_step(self, batch, batch_idx, suffix=''):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind,
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind,
)
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
self.log(
f'val{suffix}/rec_loss',
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log(
f'val{suffix}/aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
del log_dict_ae[f'val{suffix}/rec_loss']
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
lr_g = self.lr_g_factor * self.learning_rate
print('lr_d', lr_d)
print('lr_g', lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
opt_ae, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
'frequency': 1,
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
opt_disc, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
'frequency': 1,
},
]
return [opt_ae, opt_disc], scheduler
@ -235,7 +314,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
log['inputs'] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
@ -243,21 +322,24 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
log['inputs'] = x
log['reconstructions'] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
@ -283,43 +365,50 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(
2 * ddconfig['z_channels'], 2 * embed_dim, 1
)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
print(f'Restored from {path}')
def encode(self, x):
h = self.encoder(x)
@ -345,7 +434,11 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
return x
def training_step(self, batch, batch_idx, optimizer_idx):
@ -354,44 +447,102 @@ class AutoencoderKL(pl.LightningModule):
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
self.log(
'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
@ -409,17 +560,19 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
log['inputs'] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x

View File

@ -10,13 +10,13 @@ from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.modules.diffusionmodules.openaimodel import (
EncoderUNetModel,
UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = {
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}
__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
def disabled_train(self, mode=True):
@ -26,37 +26,49 @@ def disabled_train(self, mode=True):
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(self,
diffusion_path,
num_classes,
ckpt_path=None,
pool='attention',
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.e-2,
log_steps=10,
monitor='val/loss',
*args,
**kwargs):
def __init__(
self,
diffusion_path,
num_classes,
ckpt_path=None,
pool='attention',
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.0e-2,
log_steps=10,
monitor='val/loss',
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
diffusion_config = natsorted(
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.numd = (
self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
)
self.log_time_interval = (
self.diffusion_model.num_timesteps // log_steps
)
self.log_steps = log_steps
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
self.label_key = (
label_key
if not hasattr(self.diffusion_model, 'cond_stage_key')
else self.diffusion_model.cond_stage_key
)
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
assert (
self.label_key is not None
), 'label_key neither in diffusion model nor in model.params'
if self.label_key not in __models__:
raise NotImplementedError()
@ -68,22 +80,27 @@ class NoisyLatentImageClassifier(pl.LightningModule):
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
sd = torch.load(path, map_location='cpu')
if 'state_dict' in list(sd.keys()):
sd = sd['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f'Missing Keys: {missing}')
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
print(f'Unexpected Keys: {unexpected}')
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
@ -93,17 +110,25 @@ class NoisyLatentImageClassifier(pl.LightningModule):
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
model_config = deepcopy(
self.diffusion_config.params.unet_config.params
)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes
if self.label_key == 'class_label':
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print('#####################################################################')
print(
'#####################################################################'
)
print(f'load from ckpt "{ckpt_path}"')
print('#####################################################################')
print(
'#####################################################################'
)
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
@ -111,11 +136,19 @@ class NoisyLatentImageClassifier(pl.LightningModule):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(
x.shape[0], t + 1
)
)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
return self.diffusion_model.q_sample(
x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@ -141,17 +174,21 @@ class NoisyLatentImageClassifier(pl.LightningModule):
targets = rearrange(targets, 'b h w c -> b c h w')
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
targets = F.interpolate(
targets, size=(h // 2, w // 2), mode='nearest'
)
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
def compute_top_k(self, logits, labels, k, reduction='mean'):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
if reduction == 'mean':
return (
(top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
)
elif reduction == 'none':
return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self):
@ -162,29 +199,59 @@ class NoisyLatentImageClassifier(pl.LightningModule):
def write_logs(self, loss, logits, targets):
log_prefix = 'train' if self.training else 'val'
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
log[f'{log_prefix}/loss'] = loss.mean()
log[f'{log_prefix}/acc@1'] = self.compute_top_k(
logits, targets, k=1, reduction='mean'
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
log[f'{log_prefix}/acc@5'] = self.compute_top_k(
logits, targets, k=5, reduction='mean'
)
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
self.log_dict(
log,
prog_bar=False,
logger=True,
on_step=self.training,
on_epoch=True,
)
self.log(
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
)
self.log(
'global_step',
self.global_step,
logger=False,
on_epoch=False,
prog_bar=True,
)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
self.log(
'lr_abs',
lr,
on_step=True,
logger=True,
on_epoch=False,
prog_bar=True,
)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
t = torch.randint(
0,
self.diffusion_model.num_timesteps,
(x.shape[0],),
device=self.device,
).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
t = torch.full(
size=(x.shape[0],), fill_value=t, device=self.device
).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
@ -200,8 +267,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
return loss
def reset_noise_accs(self):
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
self.noisy_acc = {
t: {'acc@1': [], 'acc@5': []}
for t in range(
0,
self.diffusion_model.num_timesteps,
self.diffusion_model.log_every_t,
)
}
def on_validation_start(self):
self.reset_noise_accs()
@ -212,24 +285,35 @@ class NoisyLatentImageClassifier(pl.LightningModule):
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
self.noisy_acc[t]['acc@1'].append(
self.compute_top_k(logits, targets, k=1, reduction='mean')
)
self.noisy_acc[t]['acc@5'].append(
self.compute_top_k(logits, targets, k=5, reduction='mean')
)
return loss
def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
optimizer, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
}]
'frequency': 1,
}
]
return [optimizer], scheduler
return optimizer
@ -243,7 +327,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
y = self.get_conditioning(batch)
if self.label_key == 'class_label':
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
log['labels'] = y
if ismap(y):
@ -256,10 +340,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
log[f'inputs@t{current_time}'] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = F.one_hot(
logits.argmax(dim=1), num_classes=self.num_classes
)
pred = rearrange(pred, 'b h w c -> b c h w')
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
pred
)
for key in log:
log[key] = log[key][:N]

View File

@ -5,12 +5,16 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
@ -23,70 +27,122 @@ class DDIMSampler(object):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
@ -94,30 +150,47 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@ -126,17 +199,38 @@ class DDIMSampler(object):
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True)
iterator = tqdm(
time_range,
desc='DDIM Sampler',
total=total_steps,
dynamic_ncols=True,
)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@ -144,18 +238,30 @@ class DDIMSampler(object):
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
@ -164,42 +270,82 @@ class DDIMSampler(object):
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
alphas = (
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = (
sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@ -217,26 +363,51 @@ class DDIMSampler(object):
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False):
def decode(
self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return x_dec

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
'''wrapper around part of Karen Crownson's k-duffsion library, making it call compatible with other Samplers'''
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
import torch.nn as nn
import accelerate
class CFGDenoiser(nn.Module):
def __init__(self, model):
@ -16,59 +16,69 @@ class CFGDenoiser(nn.Module):
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class KSampler(object):
def __init__(self,model,schedule="lms", **kwargs):
def __init__(self, model, schedule='lms', device='cuda', **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.accelerator = accelerate.Accelerator()
self.device = self.accelerator.device
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
# most of these arguments are ignored and are only present for compatibility with
# other samples
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
sigmas = self.model.get_sigmas(S)
if x_T:
x = x_T
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
x = (
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
None)
def gather(samples_ddim):
return self.accelerator.gather(samples_ddim)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args
),
None,
)

View File

@ -5,11 +5,15 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class PLMSSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
@ -23,103 +27,172 @@ class PLMSSampler(object):
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for PLMS sampling is {size}')
# print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@ -128,42 +201,81 @@ class PLMSSampler(object):
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running PLMS Sampling with {total_steps} timesteps")
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
# print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True)
iterator = tqdm(
time_range,
desc='PLMS Sampler',
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
@ -172,47 +284,95 @@ class PLMSSampler(object):
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
e_t_uncond, e_t = self.model.apply_model(
x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
alphas = (
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
a_prev = torch.full(
(b, 1, 1, 1), alphas_prev[index], device=device
)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = (
sigma_t
* noise_like(x.shape, device, repeat_noise)
* temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@ -231,7 +391,12 @@ class PLMSSampler(object):
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
e_t_prime = (
55 * e_t
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

View File

@ -13,7 +13,7 @@ def exists(val):
def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()
def default(val, d):
@ -45,19 +45,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
@ -74,7 +73,9 @@ def zero_module(module):
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
@ -82,17 +83,28 @@ class LinearAttention(nn.Module):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
q, k, v = rearrange(
qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3,
)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w,
)
return self.to_out(out)
@ -102,26 +114,18 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
@ -131,12 +135,12 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
@ -146,16 +150,18 @@ class SpatialSelfAttention(nn.Module):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@ -163,8 +169,7 @@ class CrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
@ -175,7 +180,9 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)
)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@ -194,19 +201,37 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
@ -223,29 +248,43 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
)
)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
@ -258,4 +297,4 @@ class SpatialTransformer(nn.Module):
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
return x + x_in

File diff suppressed because it is too large Load Diff

View File

@ -24,6 +24,7 @@ from ldm.modules.attention import SpatialTransformer
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
@ -42,7 +43,9 @@ class AttentionPool2d(nn.Module):
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
@ -97,37 +100,45 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
"""Learned 2x upsampling without padding"""
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2
)
def forward(self,x):
def forward(self, x):
return self.up(x)
@ -140,7 +151,9 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@ -149,7 +162,12 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
)
else:
assert self.channels == self.out_channels
@ -219,7 +237,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
2 * self.out_channels
if use_scale_shift_norm
else self.out_channels,
),
)
self.out_layers = nn.Sequential(
@ -227,7 +247,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1
)
),
)
@ -238,7 +260,9 @@ class ResBlock(TimestepBlock):
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1
)
def forward(self, x, emb):
"""
@ -251,7 +275,6 @@ class ResBlock(TimestepBlock):
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@ -297,7 +320,7 @@ class AttentionBlock(nn.Module):
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
@ -312,8 +335,10 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
return checkpoint(
self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
@ -340,7 +365,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
@ -362,13 +387,15 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1
)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)
@staticmethod
@ -397,12 +424,14 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
a = th.einsum(
'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
)
return a.reshape(bs, -1, length)
@staticmethod
@ -461,19 +490,24 @@ class UNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert (
context_dim is not None
), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert (
use_spatial_transformer
), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
@ -481,10 +515,14 @@ class UNetModel(nn.Module):
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
assert (
num_head_channels != -1
), 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
assert (
num_heads != -1
), 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
@ -545,8 +583,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
@ -554,8 +596,14 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -592,8 +640,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
@ -609,9 +661,15 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
ResBlock(
ch,
time_embed_dim,
@ -646,8 +704,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
@ -655,8 +717,14 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
if level and i == num_res_blocks:
@ -673,7 +741,9 @@ class UNetModel(nn.Module):
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -682,14 +752,16 @@ class UNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)
),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
@ -707,7 +779,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@ -718,9 +790,11 @@ class UNetModel(nn.Module):
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
), 'must specify y if and only if the model is class-conditional'
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
@ -768,9 +842,9 @@ class EncoderUNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
pool='adaptive',
*args,
**kwargs
**kwargs,
):
super().__init__()
@ -888,7 +962,7 @@ class EncoderUNetModel(nn.Module):
)
self._feature_size += ch
self.pool = pool
if pool == "adaptive":
if pool == 'adaptive':
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
@ -896,7 +970,7 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
elif pool == 'attention':
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
@ -905,13 +979,13 @@ class EncoderUNetModel(nn.Module):
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == "spatial":
elif pool == 'spatial':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
elif pool == 'spatial_v2':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
@ -919,7 +993,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
raise NotImplementedError(f'Unexpected {pool} pooling')
def convert_to_fp16(self):
"""
@ -942,20 +1016,21 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
else:
h = h.type(x.dtype)
return self.out(h)

View File

@ -18,15 +18,24 @@ from einops import repeat
from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == 'linear':
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64,
)
** 2
)
elif schedule == "cosine":
elif schedule == 'cosine':
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
@ -34,23 +43,41 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
elif schedule == 'sqrt_linear':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == 'sqrt':
betas = (
torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
ddim_timesteps = (
(
np.linspace(
0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps
)
)
** 2
).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
@ -60,17 +87,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
def make_ddim_sampling_parameters(
alphacums, ddim_timesteps, eta, verbose=True
):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
alphas_prev = np.asarray(
[alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
)
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
@ -109,7 +146,9 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if False: # disabled checkpointing to allow requires_grad = False for main model
if (
False
): # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
@ -129,7 +168,9 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
@ -160,12 +201,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
@ -215,6 +260,7 @@ class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
@ -225,7 +271,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs):
@ -245,15 +291,16 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config
)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
@ -262,6 +309,8 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
return repeat_noise() if repeat else noise()

View File

@ -30,33 +30,45 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1,2,3]):
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
logtwopi
+ self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
@ -74,7 +86,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().

View File

@ -10,24 +10,30 @@ class LitEma(nn.Module):
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
else torch.tensor(-1,dtype=torch.int))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
#remove as '.'-character is not allowed in buffers
s_name = name.replace('.','')
self.m_name2s_name.update({name:s_name})
self.register_buffer(s_name,p.clone().detach().data)
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self,model):
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
decay = min(
self.decay, (1 + self.num_updates) / (10 + self.num_updates)
)
one_minus_decay = 1.0 - decay
@ -38,8 +44,12 @@ class LitEma(nn.Module):
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key]
)
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
@ -48,7 +58,9 @@ class LitEma(nn.Module):
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data
)
else:
assert not key in self.m_name2s_name

View File

@ -8,18 +8,29 @@ from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer
from functools import partial
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
DEFAULT_PLACEHOLDER_TOKEN = ['*']
PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
batch_encoding = tokenizer(
string,
truncation=True,
max_length=77,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids']
assert (
torch.count_nonzero(tokens - 49407) == 2
), f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
@ -28,42 +39,54 @@ def get_bert_token_for_string(tokenizer, string):
return token
def get_embedding_for_clip_token(embedder, token):
return embedder(token.unsqueeze(0))[0, 0]
class EmbeddingManager(nn.Module):
def __init__(
self,
embedder,
placeholder_strings=None,
initializer_words=None,
per_image_tokens=False,
num_vectors_per_token=1,
progressive_words=False,
**kwargs
self,
embedder,
placeholder_strings=None,
initializer_words=None,
per_image_tokens=False,
num_vectors_per_token=1,
progressive_words=False,
**kwargs,
):
super().__init__()
self.string_to_token_dict = {}
self.string_to_param_dict = nn.ParameterDict()
self.initial_embeddings = nn.ParameterDict() # These should not be optimized
self.initial_embeddings = (
nn.ParameterDict()
) # These should not be optimized
self.progressive_words = progressive_words
self.progressive_counter = 0
self.max_vectors_per_token = num_vectors_per_token
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
if hasattr(
embedder, 'tokenizer'
): # using Stable Diffusion's CLIP encoder
self.is_clip = True
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
get_token_for_string = partial(
get_clip_token_for_string, embedder.tokenizer
)
get_embedding_for_tkn = partial(
get_embedding_for_clip_token,
embedder.transformer.text_model.embeddings,
)
token_dim = 1280
else: # using LDM's BERT encoder
else: # using LDM's BERT encoder
self.is_clip = False
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
get_token_for_string = partial(
get_bert_token_for_string, embedder.tknz_fn
)
get_embedding_for_tkn = embedder.transformer.token_emb
token_dim = 1280
@ -71,69 +94,126 @@ class EmbeddingManager(nn.Module):
placeholder_strings.extend(per_img_token_list)
for idx, placeholder_string in enumerate(placeholder_strings):
token = get_token_for_string(placeholder_string)
if initializer_words and idx < len(initializer_words):
init_word_token = get_token_for_string(initializer_words[idx])
with torch.no_grad():
init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())
init_word_embedding = get_embedding_for_tkn(
init_word_token.cpu()
)
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
token_params = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=True,
)
self.initial_embeddings[
placeholder_string
] = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=False,
)
else:
token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
token_params = torch.nn.Parameter(
torch.rand(
size=(num_vectors_per_token, token_dim),
requires_grad=True,
)
)
self.string_to_token_dict[placeholder_string] = token
self.string_to_param_dict[placeholder_string] = token_params
def forward(
self,
tokenized_text,
embedded_text,
self,
tokenized_text,
embedded_text,
):
b, n, device = *tokenized_text.shape, tokenized_text.device
for placeholder_string, placeholder_token in self.string_to_token_dict.items():
for (
placeholder_string,
placeholder_token,
) in self.string_to_token_dict.items():
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
placeholder_embedding = self.string_to_param_dict[
placeholder_string
].to(device)
if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
if (
self.max_vectors_per_token == 1
): # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(
tokenized_text == placeholder_token.to(device)
)
embedded_text[placeholder_idx] = placeholder_embedding
else: # otherwise, need to insert and keep track of changing indices
else: # otherwise, need to insert and keep track of changing indices
if self.progressive_words:
self.progressive_counter += 1
max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
max_step_tokens = (
1 + self.progressive_counter // PROGRESSIVE_SCALE
)
else:
max_step_tokens = self.max_vectors_per_token
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
num_vectors_for_token = min(
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if placeholder_rows.nelement() == 0:
continue
sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
sorted_cols, sort_idx = torch.sort(
placeholder_cols, descending=True
)
sorted_rows = placeholder_rows[sort_idx]
for idx in range(len(sorted_rows)):
row = sorted_rows[idx]
col = sorted_cols[idx]
new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
new_token_row = torch.cat(
[
tokenized_text[row][:col],
placeholder_token.repeat(num_vectors_for_token).to(
device
),
tokenized_text[row][col + 1 :],
],
axis=0,
)[:n]
new_embed_row = torch.cat(
[
embedded_text[row][:col],
placeholder_embedding[:num_vectors_for_token],
embedded_text[row][col + 1 :],
],
axis=0,
)[:n]
embedded_text[row] = new_embed_row
embedded_text[row] = new_embed_row
tokenized_text[row] = new_token_row
return embedded_text
def save(self, ckpt_path):
torch.save({"string_to_token": self.string_to_token_dict,
"string_to_param": self.string_to_param_dict}, ckpt_path)
torch.save(
{
'string_to_token': self.string_to_token_dict,
'string_to_param': self.string_to_param_dict,
},
ckpt_path,
)
def load(self, ckpt_path, full=True):
ckpt = torch.load(ckpt_path, map_location='cpu')
@ -144,8 +224,12 @@ class EmbeddingManager(nn.Module):
self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
def get_embedding_norms_squared(self):
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
all_params = torch.cat(
list(self.string_to_param_dict.values()), axis=0
) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(
axis=-1
) # num_placeholders
return param_norm_squared
@ -153,14 +237,19 @@ class EmbeddingManager(nn.Module):
return self.string_to_param_dict.parameters()
def embedding_to_coarse_loss(self):
loss = 0.
loss = 0.0
num_embeddings = len(self.initial_embeddings)
for key in self.initial_embeddings:
optimized = self.string_to_param_dict[key]
coarse = self.initial_embeddings[key].clone().to(optimized.device)
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
loss = (
loss
+ (optimized - coarse)
@ (optimized - coarse).T
/ num_embeddings
)
return loss

View File

@ -6,29 +6,39 @@ from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,
) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
def _expand_mask(mask, dtype, tgt_len = None):
def _expand_mask(mask, dtype, tgt_len=None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class AbstractEncoder(nn.Module):
def __init__(self):
@ -38,7 +48,6 @@ class AbstractEncoder(nn.Module):
raise NotImplementedError
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
@ -56,11 +65,17 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(
self, n_embed, n_layer, vocab_size, max_seq_len=77, device='cuda'
):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
)
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
@ -72,27 +87,42 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device='cuda', vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
from transformers import (
BertTokenizerFast,
) # TODO: add to reuquirements
# Modified to allow to run on non-internet connected compute nodes.
# Model needs to be loaded into cache from an internet-connected machine
# by running:
# from transformers import BertTokenizerFast
# BertTokenizerFast.from_pretrained("bert-base-uncased")
try:
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
self.tokenizer = BertTokenizerFast.from_pretrained(
'bert-base-uncased', local_files_only=True
)
except OSError:
raise SystemExit("* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine.")
raise SystemExit(
"* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine."
)
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
return tokens
@torch.no_grad()
@ -108,53 +138,84 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device='cuda',
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len
)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout,
)
def forward(self, text, embedding_manager=None):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)
z = self.transformer(
tokens, return_embeddings=True, embedding_manager=embedding_manager
)
return z
def encode(self, text, **kwargs):
# output of length 77
return self(text, **kwargs)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
def __init__(
self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
assert method in [
'nearest',
'linear',
'bilinear',
'trilinear',
'bicubic',
'area',
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.interpolator = partial(
torch.nn.functional.interpolate, mode=method
)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
print(
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias
)
def forward(self,x):
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
@ -162,57 +223,83 @@ class SpatialRescaler(nn.Module):
def encode(self, x):
return self(x)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
def __init__(
self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version,local_files_only=True)
self.transformer = CLIPTextModel.from_pretrained(version,local_files_only=True)
self.tokenizer = CLIPTokenizer.from_pretrained(
version, local_files_only=True
)
self.transformer = CLIPTextModel.from_pretrained(
version, local_files_only=True
)
self.device = device
self.max_length = max_length
self.freeze()
def embedding_forward(
self,
input_ids = None,
position_ids = None,
inputs_embeds = None,
embedding_manager = None,
) -> torch.Tensor:
self,
input_ids=None,
position_ids=None,
inputs_embeds=None,
embedding_manager=None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
seq_length = (
input_ids.shape[-1]
if input_ids is not None
else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
self.transformer.text_model.embeddings.forward = (
embedding_forward.__get__(self.transformer.text_model.embeddings)
)
def encoder_forward(
self,
inputs_embeds,
attention_mask = None,
causal_attention_mask = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
attention_mask=None,
causal_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@ -239,44 +326,61 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
self.transformer.text_model.encoder
)
def text_encoder_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
raise ValueError('You have to specify either input_ids')
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding_manager=embedding_manager,
)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
causal_attention_mask = _build_causal_attention_mask(
bsz, seq_len, hidden_states.dtype
).to(hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
attention_mask = _expand_mask(
attention_mask, hidden_states.dtype
)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
@ -291,17 +395,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
self.transformer.text_model.forward = text_encoder_forward.__get__(
self.transformer.text_model
)
def transformer_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
return self.text_model(
input_ids=input_ids,
@ -310,11 +416,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
embedding_manager = embedding_manager
embedding_manager=embedding_manager,
)
self.transformer.forward = transformer_forward.__get__(self.transformer)
self.transformer.forward = transformer_forward.__get__(
self.transformer
)
def freeze(self):
self.transformer = self.transformer.eval()
@ -322,9 +429,16 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False
def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs)
return z
@ -337,9 +451,17 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
def __init__(
self,
version='ViT-L/14',
device='cuda',
max_length=77,
n_repeat=1,
normalize=True,
):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.model, _ = clip.load(version, jit=False, device='cpu')
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
@ -359,7 +481,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
def encode(self, text):
z = self(text)
if z.ndim==2:
if z.ndim == 2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z
@ -367,29 +489,42 @@ class FrozenCLIPTextEmbedder(nn.Module):
class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False,
)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False,
)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
x = kornia.geometry.resize(
x,
(224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
@ -399,7 +534,8 @@ class FrozenClipImageEmbedder(nn.Module):
return self.model.encode_image(self.preprocess(x))
if __name__ == "__main__":
if __name__ == '__main__':
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)

View File

@ -1,2 +1,6 @@
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
from ldm.modules.image_degradation.bsrgan import (
degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
degradation_bsrgan_variant as degradation_fn_bsr_light,
)

View File

@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
'''
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
return im[: w - w % sf, : h - h % sf, ...]
"""
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k):
'''
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
Q = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
[x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs):
'''
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
"""
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3):
'''
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681},
year={2019}
}
'''
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
def classical_degradation(x, k, sf=3):
''' blur + downsampling
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
@ -328,10 +350,19 @@ def add_blur(img, sf=4):
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
k = anisotropic_Gaussian(
ksize=2 * random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
k = fspecial(
'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img
@ -344,7 +375,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
@ -366,19 +401,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -388,28 +430,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@ -418,7 +469,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
@ -428,10 +481,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
@ -452,7 +509,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -462,8 +519,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
@ -472,7 +532,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -487,19 +550,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
@ -544,15 +618,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
@ -561,7 +638,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -576,19 +656,33 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
@ -609,12 +703,19 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image":image}
example = {'image': image}
return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
def degradation_bsrgan_plus(
img,
sf=4,
shuffle_prob=0.5,
use_sharp=True,
lq_patchsize=64,
isp_model=None,
):
"""
This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN
@ -630,7 +731,7 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
"""
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -645,8 +746,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
else:
shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
shuffle_order[2:6] = random.sample(
shuffle_order[2:6], len(range(2, 6))
)
shuffle_order[9:13] = random.sample(
shuffle_order[9:13], len(range(9, 13))
)
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
@ -689,8 +794,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
print('check the shuffle!')
# resize to desired size
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
# add final JPEG compression noise
img = add_JPEG_noise(img)
@ -702,29 +810,37 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
if __name__ == '__main__':
print("hey")
img = util.imread_uint('utils/test.png', 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + '.png')
print('hey')
img = util.imread_uint('utils/test.png', 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print('resizing to', h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img)['image']
print(img_lq.shape)
print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png')

View File

@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
'''
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
return im[: w - w % sf, : h - h % sf, ...]
"""
@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k):
'''
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
Q = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
[x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs):
'''
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
"""
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3):
'''
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681},
year={2019}
}
'''
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
def classical_degradation(x, k, sf=3):
''' blur + downsampling
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
@ -326,16 +348,25 @@ def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
wd2 = wd2/4
wd = wd/4
wd2 = wd2 / 4
wd = wd / 4
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
k = anisotropic_Gaussian(
ksize=random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
k = fspecial(
'gaussian', random.randint(2, 4) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img
@ -348,7 +379,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
@ -370,19 +405,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -392,28 +434,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@ -422,7 +473,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
@ -432,10 +485,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
@ -456,7 +513,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@ -466,8 +523,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
@ -476,7 +536,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -491,19 +554,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
@ -548,15 +622,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
@ -565,7 +642,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@ -583,20 +663,34 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2
if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
@ -617,34 +711,41 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
example = {'image': image}
return example
if __name__ == '__main__':
print("hey")
print('hey')
img = util.imread_uint('utils/test.png', 3)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
print('resizing to', h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_hq = img
img_lq = deg_fn(img)["image"]
img_lq = deg_fn(img)['image']
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img_hq)['image']
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png')

View File

@ -6,13 +6,14 @@ import torch
import cv2
from torchvision.utils import make_grid
from datetime import datetime
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
'''
"""
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
@ -20,10 +21,22 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
"""
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
IMG_EXTENSIONS = [
'.jpg',
'.JPG',
'.jpeg',
'.JPEG',
'.png',
'.PNG',
'.ppm',
'.PPM',
'.bmp',
'.BMP',
'.tif',
]
def is_image_file(filename):
@ -49,19 +62,19 @@ def surf(Z, cmap='rainbow', figsize=None):
ax3 = plt.axes(projection='3d')
w, h = Z.shape[:2]
xx = np.arange(0,w,1)
yy = np.arange(0,h,1)
xx = np.arange(0, w, 1)
yy = np.arange(0, h, 1)
X, Y = np.meshgrid(xx, yy)
ax3.plot_surface(X,Y,Z,cmap=cmap)
#ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
ax3.plot_surface(X, Y, Z, cmap=cmap)
# ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
plt.show()
'''
"""
# --------------------------------------------
# get image pathes
# --------------------------------------------
'''
"""
def get_image_paths(dataroot):
@ -83,26 +96,26 @@ def _get_paths_from_images(path):
return images
'''
"""
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
"""
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2]
patches = []
if w > p_max and h > p_max:
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
w1.append(w-p_size)
h1.append(h-p_size)
# print(w1)
# print(h1)
w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))
h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))
w1.append(w - p_size)
h1.append(h - p_size)
# print(w1)
# print(h1)
for i in w1:
for j in h1:
patches.append(img[i:i+p_size, j:j+p_size,:])
patches.append(img[i : i + p_size, j : j + p_size, :])
else:
patches.append(img)
@ -118,11 +131,21 @@ def imssave(imgs, img_path):
for i, img in enumerate(imgs):
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
new_path = os.path.join(
os.path.dirname(img_path),
img_name + str('_s{:04d}'.format(i)) + '.png',
)
cv2.imwrite(new_path, img)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
def split_imageset(
original_dataroot,
taget_dataroot,
n_channels=3,
p_size=800,
p_overlap=96,
p_max=1000,
):
"""
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
@ -139,15 +162,18 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800,
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
#if original_dataroot == taget_dataroot:
#del img_path
imssave(
patches, os.path.join(taget_dataroot, os.path.basename(img_path))
)
# if original_dataroot == taget_dataroot:
# del img_path
'''
"""
# --------------------------------------------
# makedir
# --------------------------------------------
'''
"""
def mkdir(path):
@ -171,12 +197,12 @@ def mkdir_and_rename(path):
os.makedirs(path)
'''
"""
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
"""
# --------------------------------------------
@ -206,6 +232,7 @@ def imsave(img, img_path):
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
def imwrite(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
@ -213,7 +240,6 @@ def imwrite(img, img_path):
cv2.imwrite(img_path, img)
# --------------------------------------------
# get single image of size HxWxn_channles (BGR)
# --------------------------------------------
@ -221,7 +247,7 @@ def read_img(path):
# read image by cv2
# return: Numpy float32, HWC, BGR, [0,1]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
img = img.astype(np.float32) / 255.
img = img.astype(np.float32) / 255.0
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
@ -230,7 +256,7 @@ def read_img(path):
return img
'''
"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
@ -238,7 +264,7 @@ def read_img(path):
# numpy(single) <---> tensor
# numpy(unit) <---> tensor
# --------------------------------------------
'''
"""
# --------------------------------------------
@ -248,22 +274,22 @@ def read_img(path):
def uint2single(img):
return np.float32(img/255.)
return np.float32(img / 255.0)
def single2uint(img):
return np.uint8((img.clip(0, 1)*255.).round())
return np.uint8((img.clip(0, 1) * 255.0).round())
def uint162single(img):
return np.float32(img/65535.)
return np.float32(img / 65535.0)
def single2uint16(img):
return np.uint16((img.clip(0, 1)*65535.).round())
return np.uint16((img.clip(0, 1) * 65535.0).round())
# --------------------------------------------
@ -275,14 +301,25 @@ def single2uint16(img):
def uint2tensor4(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
.unsqueeze(0)
)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
)
# convert 2/3/4-dimensional torch tensor to uint
@ -290,7 +327,7 @@ def tensor2uint(img):
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
return np.uint8((img*255.0).round())
return np.uint8((img * 255.0).round())
# --------------------------------------------
@ -305,7 +342,12 @@ def single2tensor3(img):
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.unsqueeze(0)
)
# convert torch tensor to single
@ -316,6 +358,7 @@ def tensor2single(img):
return img
# convert torch tensor to single
def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy()
@ -327,30 +370,48 @@ def tensor2single3(img):
def single2tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1, 3)
.float()
.unsqueeze(0)
)
def single32tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.float()
.unsqueeze(0)
.unsqueeze(0)
)
def single42tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
return (
torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
)
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
'''
"""
Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
'''
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
"""
tensor = (
tensor.squeeze().float().cpu().clamp_(*min_max)
) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (
min_max[1] - min_max[0]
) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
img_np = make_grid(
tensor, nrow=int(math.sqrt(n_img)), normalize=False
).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3:
img_np = tensor.numpy()
@ -359,14 +420,17 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
img_np = tensor.numpy()
else:
raise TypeError(
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(
n_dim
)
)
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
return img_np.astype(out_type)
'''
"""
# --------------------------------------------
# Augmentation, flipe and/or rotate
# --------------------------------------------
@ -374,12 +438,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
# (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
"""
def augment_img(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
@ -399,8 +462,7 @@ def augment_img(img, mode=0):
def augment_img_tensor4(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
@ -420,8 +482,7 @@ def augment_img_tensor4(img, mode=0):
def augment_img_tensor(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
img_size = img.size()
img_np = img.data.cpu().numpy()
if len(img_size) == 3:
@ -484,11 +545,11 @@ def augment_imgs(img_list, hflip=True, rot=True):
return [_augment(img) for img in img_list]
'''
"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
"""
def modcrop(img_in, scale):
@ -497,11 +558,11 @@ def modcrop(img_in, scale):
if img.ndim == 2:
H, W = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r]
img = img[: H - H_r, : W - W_r]
elif img.ndim == 3:
H, W, C = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r, :]
img = img[: H - H_r, : W - W_r, :]
else:
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
return img
@ -511,11 +572,11 @@ def shave(img_in, border=0):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
h, w = img.shape[:2]
img = img[border:h-border, border:w-border]
img = img[border : h - border, border : w - border]
return img
'''
"""
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
@ -523,74 +584,92 @@ def shave(img_in, border=0):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
"""
def rgb2ycbcr(img, only_y=True):
'''same as matlab rgb2ycbcr
"""same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
rlt = np.matmul(
img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
def ycbcr2rgb(img):
'''same as matlab ycbcr2rgb
"""same as matlab ycbcr2rgb
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
rlt = np.matmul(
img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
) * 255.0 + [-222.921, 135.576, -276.836]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
'''bgr version of rgb2ycbcr
"""bgr version of rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
rlt = np.matmul(
img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
@ -608,11 +687,11 @@ def channel_convert(in_c, tar_type, img_list):
return img_list
'''
"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
"""
# --------------------------------------------
@ -620,17 +699,17 @@ def channel_convert(in_c, tar_type, img_list):
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
# img1 and img2 have range [0, 255]
#img1 = img1.squeeze()
#img2 = img2.squeeze()
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border]
img2 = img2[border:h-border, border:w-border]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
mse = np.mean((img1 - img2) ** 2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
@ -640,17 +719,17 @@ def calculate_psnr(img1, img2, border=0):
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
'''calculate SSIM
"""calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
#img1 = img1.squeeze()
#img2 = img2.squeeze()
"""
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border]
img2 = img2[border:h-border, border:w-border]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
if img1.ndim == 2:
return ssim(img1, img2)
@ -658,7 +737,7 @@ def calculate_ssim(img1, img2, border=0):
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
@ -667,8 +746,8 @@ def calculate_ssim(img1, img2, border=0):
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
@ -684,16 +763,17 @@ def ssim(img1, img2):
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
return ssim_map.mean()
'''
"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
"""
# matlab 'imresize' function, now only support 'bicubic'
@ -701,11 +781,14 @@ def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
def calculate_weights_indices(
in_length, out_length, scale, kernel, kernel_width, antialiasing
):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale
@ -729,8 +812,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
1, P).expand(out_length, P)
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
0, P - 1, P
).view(1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
@ -771,7 +855,11 @@ def imresize(img, scale, antialiasing=True):
if need_squeeze:
img.unsqueeze_(0)
in_C, in_H, in_W = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = 'cubic'
@ -782,9 +870,11 @@ def imresize(img, scale, antialiasing=True):
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
@ -805,7 +895,11 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[j, i, :] = (
img_aug[j, idx : idx + kernel_width, :]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension
# symmetric copying
@ -827,7 +921,9 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(
weights_W[i]
)
if need_squeeze:
out_2.squeeze_()
return out_2
@ -846,7 +942,11 @@ def imresize_np(img, scale, antialiasing=True):
img.unsqueeze_(2)
in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = 'cubic'
@ -857,9 +957,11 @@ def imresize_np(img, scale, antialiasing=True):
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
@ -880,7 +982,11 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
out_1[i, :, j] = (
img_aug[idx : idx + kernel_width, :, j]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension
# symmetric copying
@ -902,7 +1008,9 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(
weights_W[i]
)
if need_squeeze:
out_2.squeeze_()
@ -913,4 +1021,4 @@ if __name__ == '__main__':
print('---')
# img = imread_uint('test.bmp', 3)
# img = uint2single(img)
# img_bicubic = imresize_np(img, 1/4)
# img_bicubic = imresize_np(img, 1/4)

View File

@ -1 +1 @@
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator

View File

@ -5,13 +5,24 @@ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/
class LPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_loss="hinge"):
def __init__(
self,
disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss='hinge',
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert disc_loss in ['hinge', 'vanilla']
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
@ -19,42 +30,68 @@ class LPIPSWithDiscriminator(nn.Module):
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_loss = (
hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss
)
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
global_step, last_layer=None, cond=None, split="train",
weights=None):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
def forward(
self,
inputs,
reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
weights=None,
):
rec_loss = torch.abs(
inputs.contiguous() - reconstructions.contiguous()
)
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights*nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = (
torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
)
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
@ -67,45 +104,72 @@ class LPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
log = {
'{}/total_loss'.format(split): loss.clone().detach().mean(),
'{}/logvar'.format(split): self.logvar.detach(),
'{}/kl_loss'.format(split): kl_loss.detach().mean(),
'{}/nll_loss'.format(split): nll_loss.detach().mean(),
'{}/rec_loss'.format(split): rec_loss.detach().mean(),
'{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
}
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
log = {
'{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
'{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log

View File

@ -3,21 +3,25 @@ from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.discriminator.model import (
NLayerDiscriminator,
weights_init,
)
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
@ -26,57 +30,76 @@ def adopt_weight(weight, global_step, threshold=0, value=0.):
def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
encodings = (
F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
def l1(x, y):
return torch.abs(x-y)
return torch.abs(x - y)
def l2(x, y):
return torch.pow((x-y), 2)
return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
pixel_loss="l1"):
def __init__(
self,
disc_start,
codebook_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss='hinge',
n_classes=None,
perceptual_loss='lpips',
pixel_loss='l1',
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
assert disc_loss in ['hinge', 'vanilla']
assert perceptual_loss in ['lpips', 'clips', 'dists']
assert pixel_loss in ['l1', 'l2']
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
if perceptual_loss == 'lpips':
print(f'{self.__class__.__name__}: Running with LPIPS.')
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
raise ValueError(
f'Unknown perceptual loss: >> {perceptual_loss} <<'
)
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
if pixel_loss == 'l1':
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
).apply(weights_init)
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf,
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
if disc_loss == 'hinge':
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
elif disc_loss == 'vanilla':
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.')
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
@ -84,31 +107,53 @@ class VQLPIPSWithDiscriminator(nn.Module):
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
def forward(
self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
predicted_indices=None,
):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device)
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
codebook_loss = torch.tensor([0.0]).to(inputs.device)
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(
inputs.contiguous(), reconstructions.contiguous()
)
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
@ -119,49 +164,77 @@ class VQLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
log = {
'{}/total_loss'.format(split): loss.clone().detach().mean(),
'{}/quant_loss'.format(split): codebook_loss.detach().mean(),
'{}/nll_loss'.format(split): nll_loss.detach().mean(),
'{}/rec_loss'.format(split): rec_loss.detach().mean(),
'{}/p_loss'.format(split): p_loss.detach().mean(),
'{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
perplexity, cluster_usage = measure_perplexity(
predicted_indices, self.n_classes
)
log[f'{split}/perplexity'] = perplexity
log[f'{split}/cluster_usage'] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
log = {
'{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
'{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log

View File

@ -11,15 +11,13 @@ from einops import rearrange, repeat, reduce
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
Intermediates = namedtuple(
'Intermediates', ['pre_softmax_attn', 'post_softmax_attn']
)
LayerIntermediates = namedtuple('Intermediates', [
'hiddens',
'attn_intermediates'
])
LayerIntermediates = namedtuple(
'Intermediates', ['hiddens', 'attn_intermediates']
)
class AbsolutePositionalEmbedding(nn.Module):
@ -39,11 +37,16 @@ class AbsolutePositionalEmbedding(nn.Module):
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
t = (
torch.arange(x.shape[seq_dim], device=x.device).type_as(
self.inv_freq
)
+ offset
)
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
@ -51,6 +54,7 @@ class FixedPositionalEmbedding(nn.Module):
# helpers
def exists(val):
return val is not None
@ -64,18 +68,21 @@ def default(val, d):
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
@ -85,6 +92,7 @@ def max_neg_value(tensor):
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
@ -108,8 +116,15 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(
lambda x: (x[0][len(prefix) :], x[1]),
tuple(kwargs_with_prefix.items()),
)
)
return kwargs_without_prefix, kwargs
@ -139,7 +154,7 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@ -151,7 +166,7 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
@ -173,7 +188,7 @@ class GRUGating(nn.Module):
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
rearrange(residual, 'b n d -> (b n) d'),
)
return gated_output.reshape_as(x)
@ -181,6 +196,7 @@ class GRUGating(nn.Module):
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@ -192,19 +208,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
@ -214,23 +229,25 @@ class FeedForward(nn.Module):
# attention.
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.0,
on_attn=False,
):
super().__init__()
if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
self.scale = dim_head ** -0.5
raise NotImplementedError(
'Check out entmax activation instead of softmax activation!'
)
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
@ -252,7 +269,7 @@ class Attention(nn.Module):
self.sparse_topk = sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
# self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax
# add memory key / values
@ -263,20 +280,29 @@ class Attention(nn.Module):
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(inner_dim, dim)
)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None,
):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
b, n, _, h, talking_heads, device = (
*x.shape,
self.heads,
self.talking_heads,
x.device,
)
kv_input = default(context, x)
q_input = x
@ -297,23 +323,35 @@ class Attention(nn.Module):
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)
)
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
q_mask = default(
mask, lambda: torch.ones((b, n), device=device).bool()
)
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
k_mask = default(
k_mask,
lambda: torch.ones((b, k.shape[-2]), device=device).bool(),
)
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = map(
lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v),
)
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True
)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots)
@ -324,7 +362,9 @@ class Attention(nn.Module):
pre_softmax_attn = dots
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
dots = einsum(
'b h i j, h k -> b k i j', dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
@ -336,7 +376,9 @@ class Attention(nn.Module):
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j'
)
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
@ -354,14 +396,16 @@ class Attention(nn.Module):
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
attn = einsum(
'b h i j, h k -> b k i j', attn, self.post_softmax_proj
).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
post_softmax_attn=post_softmax_attn,
)
return self.to_out(out), intermediates
@ -369,28 +413,28 @@ class Attention(nn.Module):
class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs,
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
@ -403,10 +447,14 @@ class AttentionLayers(nn.Module):
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
assert (
rel_pos_num_buckets <= rel_pos_max_distance
), 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None
self.pre_norm = pre_norm
@ -438,15 +486,27 @@ class AttentionLayers(nn.Module):
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
depth_cut = (
par_depth * 2 // 3
) # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
assert (
len(default_block) <= par_width
), 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (
par_width - len(default_block)
)
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
assert (
sandwich_coef > 0 and sandwich_coef <= depth
), 'sandwich coefficient should be less than the depth'
layer_types = (
('a',) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ('f',) * sandwich_coef
)
else:
layer_types = default_block * depth
@ -455,7 +515,9 @@ class AttentionLayers(nn.Module):
for layer_type in self.layer_types:
if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs
)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
@ -472,21 +534,17 @@ class AttentionLayers(nn.Module):
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([
norm_fn(),
layer,
residual_fn
]))
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False,
**kwargs
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False,
**kwargs,
):
hiddens = []
intermediates = []
@ -495,7 +553,9 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1)
if layer_type == 'a':
@ -508,10 +568,22 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
prev_attn=prev_attn, mem=layer_mem)
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == 'f':
out = block(x)
@ -530,8 +602,7 @@ class AttentionLayers(nn.Module):
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates
hiddens=hiddens, attn_intermediates=intermediates
)
return x, intermediates
@ -545,23 +616,24 @@ class Encoder(AttentionLayers):
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.0,
emb_dropout=0.0,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True,
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
@ -571,23 +643,34 @@ class TransformerWrapper(nn.Module):
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.pos_emb = (
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.project_emb = (
nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
)
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim)
)
# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'):
@ -597,20 +680,20 @@ class TransformerWrapper(nn.Module):
nn.init.normal_(self.token_emb.weight, std=0.02)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
embedding_manager=None,
**kwargs
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
embedding_manager=None,
**kwargs,
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
embedded_x = self.token_emb(x)
if embedding_manager:
x = embedding_manager(x, embedded_x)
else:
@ -629,7 +712,9 @@ class TransformerWrapper(nn.Module):
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
@ -638,13 +723,30 @@ class TransformerWrapper(nn.Module):
if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
new_mems = (
list(
map(
lambda pair: torch.cat(pair, dim=-2),
zip(mems, hiddens),
)
)
if exists(mems)
else hiddens
)
new_mems = list(
map(
lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems
)
)
return out, new_mems
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = list(
map(
lambda t: t.post_softmax_attn,
intermediates.attn_intermediates,
)
)
return out, attn_maps
return out

File diff suppressed because it is too large Load Diff

View File

@ -13,22 +13,25 @@ from queue import Queue
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
lines = '\n'.join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
draw.text((0, 0), lines, fill="black", font=font)
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
print('Cant encode string for logging. Skipping.')
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@ -70,22 +73,26 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
)
return total_params
def instantiate_from_config(config, **kwargs):
if not "target" in config:
if not 'target' in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
elif config == '__is_unconditional__':
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(
**config.get('params', dict()), **kwargs
)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
module, cls = string.rsplit('.', 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
@ -101,31 +108,36 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")
Q.put('Done')
def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
func: callable,
data,
n_proc,
target_data_type='ndarray',
cpu_intensive=True,
use_worker_id=False,
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
if isinstance(data, np.ndarray) and target_data_type == 'list':
raise ValueError('list expected but function got ndarray.')
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
if target_data_type == 'ndarray':
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
)
if cpu_intensive:
@ -135,7 +147,7 @@ def parallel_data_prefetch(
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == "ndarray":
if target_data_type == 'ndarray':
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
@ -149,7 +161,7 @@ def parallel_data_prefetch(
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i: i + step] for i in range(0, len(data), step)]
[data[i : i + step] for i in range(0, len(data), step)]
)
]
processes = []
@ -158,7 +170,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print(f"Start prefetching...")
print(f'Start prefetching...')
import time
start = time.time()
@ -171,13 +183,13 @@ def parallel_data_prefetch(
while k < n_proc:
# get result
res = Q.get()
if res == "Done":
if res == 'Done':
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
print('Exception: ', e)
for p in processes:
p.terminate()
@ -185,7 +197,7 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
print(f'Prefetching complete. [{time.time() - start} sec.]')
if target_data_type == 'ndarray':
if not isinstance(gather_res[0], np.ndarray):

685
main.py

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,3 @@
accelerate==0.12.0
albumentations==0.4.3
einops==0.3.0
huggingface-hub==0.8.1

View File

@ -3,427 +3,513 @@
import argparse
import shlex
import atexit
import os
import sys
import copy
from PIL import Image,PngImagePlugin
# readline unavailable on windows systems
try:
import readline
readline_available = True
except:
readline_available = False
import warnings
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
debugging = False
def main():
''' Initialize command-line parsers and the diffusion model '''
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
opt = arg_parser.parse_args()
if opt.laion400m:
# defaults suitable to the older latent diffusion weights
width = 256
height = 256
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
weights = "models/ldm/text2img-large/model.ckpt"
width = 256
height = 256
config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml'
weights = 'models/ldm/text2img-large/model.ckpt'
else:
# some defaults suitable for stable diffusion weights
width = 512
height = 512
config = "configs/stable-diffusion/v1-inference.yaml"
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
width = 512
height = 512
config = 'configs/stable-diffusion/v1-inference.yaml'
weights = 'models/ldm/stable-diffusion-v1/model.ckpt'
# command line history will be stored in a file called "~/.dream_history"
if readline_available:
setup_readline()
print("* Initializing, be patient...\n")
print('* Initializing, be patient...\n')
sys.path.append('.')
from pytorch_lightning import logging
from ldm.simplet2i import T2I
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
# creating a simple text2image object with a handful of
# defaults passed on the command line.
# additional parameters will be added (or overriden) during
# the user input loop
t2i = T2I(width=width,
height=height,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
embedding_path=opt.embedding_path,
device=opt.device
t2i = T2I(
width=width,
height=height,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
embedding_path=opt.embedding_path,
device=opt.device,
)
# make sure the output directory exists
if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
infile = None
try:
if opt.infile is not None:
infile = open(opt.infile,'r')
infile = open(opt.infile, 'r')
except FileNotFoundError as e:
print(e)
exit(-1)
# preload the model
t2i.load_model()
print("\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...")
log_path = os.path.join(opt.outdir,'dream_log.txt')
with open(log_path,'a') as log:
# load GFPGAN if requested
if opt.use_gfpgan:
print('\n* --gfpgan was specified, loading gfpgan...')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
try:
model_path = os.path.join(
opt.gfpgan_dir, opt.gfpgan_model_path
)
if not os.path.isfile(model_path):
raise Exception(
'GFPGAN model not found at path ' + model_path
)
sys.path.append(os.path.abspath(opt.gfpgan_dir))
from gfpgan import GFPGANer
bg_upsampler = load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile
)
t2i.gfpgan = GFPGANer(
model_path=model_path,
upscale=opt.gfpgan_upscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler,
)
except Exception:
import traceback
print('Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)..."
)
log_path = os.path.join(opt.outdir, 'dream_log.txt')
with open(log_path, 'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log,infile)
main_loop(t2i, opt.outdir, cmd_parser, log, infile)
log.close()
if infile:
infile.close()
def main_loop(t2i,parser,log,infile):
''' prompt/read/execute loop '''
def main_loop(t2i, outdir, parser, log, infile):
"""prompt/read/execute loop"""
done = False
last_seeds = []
while not done:
try:
command = infile.readline() if infile else input("dream> ")
command = infile.readline() if infile else input('dream> ')
except EOFError:
done = True
break
if infile and len(command)==0:
if infile and len(command) == 0:
done = True
break
if command.startswith(('#','//')):
if command.startswith(('#', '//')):
continue
# before splitting, escape single quotes so as not to mess
# up the parser
command = command.replace("'","\\'")
command = command.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError as e:
print(str(e))
continue
if len(elements)==0:
if len(elements) == 0:
continue
if elements[0]=='q':
if elements[0] == 'q':
done = True
break
if elements[0]=='cd' and len(elements)>1:
if elements[0] == 'cd' and len(elements) > 1:
if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}")
t2i.outdir=elements[1]
print(f'setting image output directory to {elements[1]}')
outdir = elements[1]
else:
print(f"directory {elements[1]} does not exist")
print(f'directory {elements[1]} does not exist')
continue
if elements[0]=='pwd':
print(f"current output directory is {t2i.outdir}")
if elements[0] == 'pwd':
print(f'current output directory is {outdir}')
continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
if elements[0].startswith(
'!dream'
): # in case a stored prompt still contains the !dream command
elements.pop(0)
# rearrange the arguments to mimic how it works in the Dream bot.
switches = ['']
switches_started = False
for el in elements:
if el[0]=='-' and not switches_started:
if el[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][:len(switches[0])-1]
switches[0] = switches[0][: len(switches[0]) - 1]
try:
opt = parser.parse_args(switches)
opt = parser.parse_args(switches)
except SystemExit:
parser.print_help()
continue
if len(opt.prompt)==0:
print("Try again with a prompt!")
if len(opt.prompt) == 0:
print('Try again with a prompt!')
continue
if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try:
opt.seed = last_seeds[opt.seed]
print(f'reusing previous seed {opt.seed}')
except IndexError:
print(f'No previous seed at position {opt.seed} found')
opt.seed = None
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
individual_images = not opt.grid
try:
if opt.init_img is None:
results = t2i.txt2img(**vars(opt))
else:
assert os.path.exists(opt.init_img),f"No file found at {opt.init_img}. On Linux systems, pressing <tab> after -I will autocomplete a list of possible image files."
if None not in (opt.width,opt.height):
print('Warning: width and height options are ignored when modifying an init image')
results = t2i.img2img(**vars(opt))
file_writer = PngWriter(outdir, normalized_prompt, opt.batch_size)
callback = file_writer.write_image if individual_images else None
image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
results = (
file_writer.files_written if individual_images else image_list
)
if opt.grid and len(results) > 0:
grid_img = file_writer.make_grid([r[0] for r in results])
filename = file_writer.unique_filename(results[0][1])
seeds = [a[1] for a in results]
results = [[filename, seeds]]
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
)
last_seeds = [r[1] for r in results]
except AssertionError as e:
print(e)
continue
except OSError as e:
print(e)
continue
allVariantResults = []
if opt.variants is not None:
print(f"Generating {opt.variants} variant(s)...")
newopt = copy.deepcopy(opt)
newopt.iterations = 1
newopt.variants = None
for r in results:
newopt.init_img = r[0]
print(f"\t generating variant for {newopt.init_img}")
for j in range(0, opt.variants):
try:
variantResults = t2i.img2img(**vars(newopt))
allVariantResults.append([newopt,variantResults])
except AssertionError as e:
print(e)
continue
print(f"{opt.variants} Variants generated!")
print('Outputs:')
write_log_message(t2i, normalized_prompt, results, log)
print("Outputs:")
write_log_message(t2i,opt,results,log)
if allVariantResults:
print("Variant outputs:")
for vr in allVariantResults:
write_log_message(t2i,vr[0],vr[1],log)
print("goodbye!")
print('goodbye!')
def write_log_message(t2i,opt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
switches = _reconstruct_switches(t2i,opt)
prompt_str = ' '.join(switches)
def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
import torch
# when multiple images are produced in batch, then we keep track of where each starts
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
seenit = {}
if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
import warnings
seeds = [a[1] for a in results]
if batch_size > 1:
seeds = f"(seeds for each batch row: {seeds})"
warnings.warn(
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.'
)
bg_upsampler = None
else:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=model,
tile=bg_tile,
tile_pad=10,
pre_pad=0,
half=True,
) # need to set False in CPU mode
else:
seeds = f"(seeds for individual images: {seeds})"
bg_upsampler = None
return bg_upsampler
# variant generation is going to be superseded by a generalized
# "prompt-morph" functionality
# def generate_variants(t2i,outdir,opt,previous_gens):
# variants = []
# print(f"Generating {opt.variants} variant(s)...")
# newopt = copy.deepcopy(opt)
# newopt.iterations = 1
# newopt.variants = None
# for r in previous_gens:
# newopt.init_img = r[0]
# prompt = PromptFormatter(t2i,newopt).normalize_prompt()
# print(f"] generating variant for {newopt.init_img}")
# for j in range(0,opt.variants):
# try:
# file_writer = PngWriter(outdir,prompt,newopt.batch_size)
# callback = file_writer.write_image
# t2i.prompt2image(image_callback=callback,**vars(newopt))
# results = file_writer.files_written
# variants.append([prompt,results])
# except AssertionError as e:
# print(e)
# continue
# print(f'{opt.variants} variants generated')
# return variants
def write_log_message(t2i, prompt, results, logfile):
"""logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata"""
last_seed = None
img_num = 1
seenit = {}
for r in results:
seed = r[1]
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
log_message = f'{r[0]}: {prompt} -S{seed}'
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += f' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message)
logfile.write(log_message+"\n")
logfile.write(log_message + '\n')
logfile.flush()
if r[0] not in seenit:
seenit[r[0]] = True
try:
if opt.grid:
_write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
else:
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
except FileNotFoundError:
print(f"Could not open file '{r[0]}' for reading")
def _reconstruct_switches(t2i,opt):
'''Normalize the prompt and switches'''
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return switches
def _write_prompt_to_png(path,prompt):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
im = Image.open(path)
im.save(path,"PNG",pnginfo=info)
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
"--latent_diffusion",
"-l",
dest='laion400m',
action='store_true',
help="fallback to the latent diffusion (laion400m) weights and config")
parser.add_argument("--from_file",
dest='infile',
type=str,
help="if specified, load prompts from this file")
parser.add_argument('-n','--iterations',
type=int,
default=1,
help="number of images to generate")
parser.add_argument('-F','--full_precision',
dest='full_precision',
action='store_true',
help="use slower full precision math for calculations")
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
default='k_lms',
help="which sampler to use (k_lms) - can only be set on command line")
parser.add_argument('--outdir',
'-o',
type=str,
default="outputs/img-samples",
help="directory in which to place generated images and a log of prompts and seeds")
parser.add_argument('--embedding_path',
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
parser.add_argument('--device',
'-d',
type=str,
default="cuda",
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if avalible")
parser = argparse.ArgumentParser(
description="Parse script's command line args"
)
parser.add_argument(
'--laion400m',
'--latent_diffusion',
'-l',
dest='laion400m',
action='store_true',
help='fallback to the latent diffusion (laion400m) weights and config',
)
parser.add_argument(
'--from_file',
dest='infile',
type=str,
help='if specified, load prompts from this file',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='number of images to generate',
)
parser.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help='use slower full precision math for calculations',
)
parser.add_argument(
'--sampler',
'-m',
dest='sampler_name',
choices=[
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
],
default='k_lms',
help='which sampler to use (k_lms) - can only be set on command line',
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default='outputs/img-samples',
help='directory in which to place generated images and a log of prompts and seeds (outputs/img-samples',
)
parser.add_argument(
'--embedding_path',
type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
)
parser.add_argument(
'--device',
'-d',
type=str,
default='cuda',
help='device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if avalible',
)
# GFPGAN related args
parser.add_argument(
'--gfpgan',
dest='use_gfpgan',
action='store_true',
help='load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory',
)
parser.add_argument(
'--gfpgan_upscale',
type=int,
default=2,
help='The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_bg_upsampler',
type=str,
default='realesrgan',
help='Background upsampler. Default: None. Options: realesrgan, none. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_model_path',
type=str,
default='experiments/pretrained_models/GFPGANv1.3.pth',
help='indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_dir',
type=str,
default='../GFPGAN',
help='indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified',
)
return parser
def create_cmd_parser():
parser = argparse.ArgumentParser(description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12')
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
)
parser.add_argument('prompt')
parser.add_argument('-s','--steps',type=int,help="number of steps")
parser.add_argument('-S','--seed',type=int,help="image seed")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform (slower, but will provide seeds for individual images)")
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (will not provide seeds for individual images!)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)")
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
parser.add_argument('-s', '--steps', type=int, help='number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-b',
'--batch_size',
type=int,
default=1,
help='number of images to produce per sampling (will not provide seeds for individual images!)',
)
parser.add_argument(
'-W', '--width', type=int, help='image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='prompt configuration scale',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=None,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='skip subprompt weight normalization',
)
return parser
if readline_available:
def setup_readline():
readline.set_completer(Completer(['cd','pwd',
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
readline.set_completer_delims(" ")
readline.parse_and_bind('tab: complete')
load_history()
def load_history():
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file,histfile)
class Completer():
def __init__(self,options):
self.options = sorted(options)
return
def complete(self,text,state):
buffer = readline.get_line_buffer()
if text.startswith(('-I','--init_img')):
return self._path_completions(text,state,('.png'))
if buffer.strip().endswith('cd') or text.startswith(('.','/')):
return self._path_completions(text,state,())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [s
for s in self.options
if s and s.startswith(text)]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _path_completions(self,text,state,extensions):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I','',1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=','',1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path)==0:
matches.append(text+'./')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n)>1:
continue
full_path = os.path.join(dir,n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(os.path.join(os.path.dirname(text),n)+'/')
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text),n))
try:
response = matches[state]
except IndexError:
response = None
return response
if __name__ == "__main__":
if __name__ == '__main__':
main()

114
scripts/dream_web.py Normal file
View File

@ -0,0 +1,114 @@
import json
import base64
import mimetypes
import os
from pytorch_lightning import logging
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
print("Loading model...")
from ldm.simplet2i import T2I
model = T2I(sampler_name='k_lms')
# to get rid of annoying warning messages from pytorch
import transformers
transformers.logging.set_verbosity_error()
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
print("Initializing model, be patient...")
model.load_model()
class DreamServer(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
with open("./static/dream_web/index.html", "rb") as content:
self.wfile.write(content.read())
elif os.path.exists("." + self.path):
mime_type = mimetypes.guess_type(self.path)[0]
if mime_type is not None:
self.send_response(200)
self.send_header("Content-type", mime_type)
self.end_headers()
with open("." + self.path, "rb") as content:
self.wfile.write(content.read())
else:
self.send_response(404)
else:
self.send_response(404)
def do_POST(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length))
prompt = post_data['prompt']
initimg = post_data['initimg']
iterations = int(post_data['iterations'])
steps = int(post_data['steps'])
width = int(post_data['width'])
height = int(post_data['height'])
cfgscale = float(post_data['cfgscale'])
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
print(f"Request to generate with prompt: {prompt}")
outputs = []
if initimg is None:
# Run txt2img
outputs = model.txt2img(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
# Run img2img
outputs = model.img2img(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps)
# Remove the temp file
os.remove("./img2img-tmp.png")
print(f"Prompt generated with output: {outputs}")
post_data['initimg'] = '' # Don't send init image back
# Append post_data to log
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
for output in outputs:
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
outputs = [x + [post_data] for x in outputs] # Append config to each output
result = {'outputs': outputs}
self.wfile.write(bytes(json.dumps(result), "utf-8"))
if __name__ == "__main__":
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))
)
# Start server
dream_server = ThreadingHTTPServer(("0.0.0.0", 9090), DreamServer)
print("\n\n* Started Stable Diffusion dream server! Point your browser at http://localhost:9090 or use the host's DNS name or IP address. *")
try:
dream_server.serve_forever()
except KeyboardInterrupt:
pass
dream_server.server_close()

View File

@ -12,7 +12,6 @@ from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import k_diffusion as K
import torch.nn as nn
@ -201,8 +200,6 @@ def main():
#for klms
model_wrap = K.external.CompVisDenoiser(model)
accelerator = accelerate.Accelerator()
device = accelerator.device
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
@ -251,8 +248,8 @@ def main():
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
@ -279,13 +276,10 @@ def main():
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if opt.klms:
x_sample = accelerator.gather(x_samples_ddim)
if not opt.skip_save:
for x_sample in x_samples_ddim:

View File

@ -5,30 +5,68 @@
# two machines must share a common .cache directory.
import sys
import transformers
import os
import warnings
transformers.logging.set_verbosity_error()
# this will preload the Bert tokenizer fles
print("preloading bert tokenizer...")
print('preloading bert tokenizer...')
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
print("...success")
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
# this will download requirements for Kornia
print("preloading Kornia requirements (ignore the warnings)...")
import kornia
print("...success")
print('preloading Kornia requirements (ignore the deprecation warnings)...')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia
print('...success')
# doesn't work - probably wrong logger
# logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)
version='openai/clip-vit-large-patch14'
version = 'openai/clip-vit-large-patch14'
print('preloading CLIP model (Ignore the warnings)...')
print('preloading CLIP model (Ignore the deprecation warnings)...')
sys.stdout.flush()
import clip
from transformers import CLIPTokenizer, CLIPTextModel
tokenizer =CLIPTokenizer.from_pretrained(version)
transformer=CLIPTextModel.from_pretrained(version)
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('\n\n...success')
# In the event that the user has installed GFPGAN and also elected to use
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
gfpgan = False
try:
from realesrgan import RealESRGANer
gfpgan = True
except ModuleNotFoundError:
pass
if gfpgan:
print('Loading models from RealESRGAN and facexlib')
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
),
)
FaceRestoreHelper(1, det_model='retinaface_resnet50')
print('...success')
except Exception:
import traceback
print('Error loading GFPGAN:')
print(traceback.format_exc())

@ -1 +0,0 @@
Subproject commit d50d76daa670286dd6cacf3bcd80b5e4823fc8e1

@ -1 +0,0 @@
Subproject commit ef1bf07627c9a10ba9137e68a0206b844544a7d9

@ -1 +0,0 @@
Subproject commit 24268930bf1dce879235a7fddd0b2355b84d7ea6

View File

@ -0,0 +1,61 @@
* {
font-family: 'Arial';
}
#header {
text-decoration: dotted underline;
}
#search {
margin-top: 20vh;
margin-left: auto;
margin-right: auto;
max-width: 800px;
text-align: center;
}
fieldset {
border: none;
}
#fieldset-search {
display: flex;
}
#prompt {
flex-grow: 1;
border-radius: 20px 0px 0px 20px;
padding: 5px 10px 5px 10px;
border: 1px solid black;
border-right: none;
outline: none;
}
#submit {
border-radius: 0px 20px 20px 0px;
padding: 5px 10px 5px 10px;
border: 1px solid black;
}
#results {
text-align: center;
max-width: 1000px;
margin: auto;
padding-top: 10px;
}
img {
cursor: pointer;
height: 30vh;
border-radius: 5px;
margin: 10px;
}
#fieldset-config {
line-height:2em;
}
input[type="number"] {
width: 60px;
}
#seed {
width: 150px;
}
hr {
width: 200px;
}
label {
white-space: nowrap;
}

View File

@ -0,0 +1,48 @@
<html>
<head>
<title>Stable Diffusion Dream Server</title>
<link rel="icon" href="data:,">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="static/dream_web/index.css">
<script src="static/dream_web/index.js"></script>
</head>
<body>
<div id="search">
<h2 id="header">Stable Diffusion Dream Server</h2>
<form id="generate-form" method="post" action="#">
<fieldset id="fieldset-search">
<input type="text" id="prompt" name="prompt">
<input type="submit" id="submit" value="Generate">
</fieldset>
<fieldset id="fieldset-config">
<label for="iterations">Images to generate:</label>
<input value="1" type="number" id="iterations" name="iterations">
<label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps">
<label for="cfgscale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfgscale" name="cfgscale" step="any">
<span>&bull;</span>
<label title="Set to multiple of 64" for="width">Width:</label>
<input value="512" type="number" id="width" name="width">
<label title="Set to multiple of 64" for="height">Height:</label>
<input value="512" type="number" id="height" name="height">
<br>
<label title="Upload an image to use img2img" for="initimg">Img2Img Init:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<label title="Set to -1 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed">
<button type="button" id="reset">&olarr;</button>
</fieldset>
</form>
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
</div>
<hr>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>
</div>
</div>
</body>
</html>

101
static/dream_web/index.js Normal file
View File

@ -0,0 +1,101 @@
function toBase64(file) {
return new Promise((resolve, reject) => {
const r = new FileReader();
r.readAsDataURL(file);
r.onload = () => resolve(r.result);
r.onerror = (error) => reject(error);
});
}
function appendOutput(output) {
let outputNode = document.createElement("img");
outputNode.src = output[0];
let outputConfig = output[2];
let altText = output[1].toString() + " | " + outputConfig.prompt;
outputNode.alt = altText;
outputNode.title = altText;
// Reload image config
outputNode.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
}
document.querySelector("#seed").value = output[1];
saveFields(document.querySelector("#generate-form"));
});
document.querySelector("#results").prepend(outputNode);
}
function appendOutputs(outputs) {
for (const output of outputs) {
appendOutput(output);
}
}
function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
localStorage.setItem(k, v);
}
}
}
function loadFields(form) {
for (const [k, v] of new FormData(form)) {
const item = localStorage.getItem(k);
if (item != null) {
form.querySelector(`*[name=${k}]`).value = item;
}
}
}
async function generateSubmit(form) {
const prompt = document.querySelector("#prompt").value;
// Convert file data to base64
let formData = Object.fromEntries(new FormData(form));
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Post as JSON
fetch(form.action, {
method: form.method,
body: JSON.stringify(formData),
}).then(async (result) => {
let data = await result.json();
// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
if (data.outputs.length != 0) {
document.querySelector("#no-results-message")?.remove();
appendOutputs(data.outputs);
} else {
alert("Error occurred while generating.");
}
});
// Disable form while generating
form.querySelector('fieldset').setAttribute('disabled','');
document.querySelector("#prompt").value = `Generating: "${prompt}"`;
}
window.onload = () => {
document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
generateSubmit(form);
});
document.querySelector("#generate-form").addEventListener('change', (e) => {
saveFields(e.target.form);
});
document.querySelector("#reset").addEventListener('click', (e) => {
document.querySelector("#seed").value = -1;
saveFields(e.target.form);
});
loadFields(document.querySelector("#generate-form"));
};

BIN
static/dream_web_server.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 536 KiB