Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 35d3f0ed90: Merge branch 'lstein:main' into main
@@ -1,20 +1,19 @@
 # Apple Silicon Mac Users
 
 Several people have gotten Stable Diffusion to work on Apple Silicon
-Macs using Anaconda. I've gathered up most of their instructions and
-put them in this fork (and readme). I haven't tested anything besides
-Anaconda, and I've read about issues with things like miniforge, so if
-you have an issue that isn't dealt with in this fork then head on over
-to the [Apple
-Silicon](https://github.com/CompVis/stable-diffusion/issues/25) issue
-on GitHub (that page is so long that GitHub hides most of it by
-default, so you need to find the hidden part and expand it to view the
-whole thing). This fork would not have been possible without the work
-done by the people on that issue.
+Macs using Anaconda, miniforge, etc. I've gathered up most of their instructions and
+put them in this fork (and readme). Things have moved really fast and so these
+instructions change often. Hopefully things will settle down a little.
+
+There's several places where people are discussing Apple
+MPS functionality: [the original CompVis
+issue](https://github.com/CompVis/stable-diffusion/issues/25), and generally on
+[lstein's fork](https://github.com/lstein/stable-diffusion/).
 
 You have to have macOS 12.3 Monterey or later. Anything earlier than that won't work.
 
-BTW, I haven't tested any of this on Intel Macs.
+BTW, I haven't tested any of this on Intel Macs but I have read that one person
+got it to work.
 
 How to:
@@ -27,38 +26,41 @@ ln -s /path/to/ckpt/sd-v1-1.ckpt models/ldm/stable-diffusion-v1/model.ckpt
 
 conda env create -f environment-mac.yaml
 conda activate ldm
 
+python scripts/preload_models.py
+python scripts/orig_scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
 ```
 
-These instructions are identical to the main repo except I added
-environment-mac.yaml because Mac doesn't have cudatoolkit.
+We have not gotten lstein's dream.py to work yet.
 
 After you follow all the instructions and run txt2img.py you might get several errors. Here's the errors I've seen and found solutions for.
 
+### Is it slow?
+
+Be sure to specify 1 sample and 1 iteration.
+
+python ./scripts/txt2img.py --prompt "ocean" --ddim_steps 5 --n_samples 1 --n_iter 1
+
 ### Doesn't work anymore?
 
-We are using PyTorch nightly, which includes support for MPS. I don't
-know exactly how Anaconda does updates, but I woke up one morning and
-Stable Diffusion crashed and I couldn't think of anything I did that
-would've changed anything the night before, when it worked. A day and
-a half later I finally got it working again. I don't know what changed
-overnight. PyTorch-nightly changes overnight but I'm pretty sure I
-didn't manually update it. Either way, things are probably going to be
-bumpy on Apple Silicon until PyTorch releases a firm version that we
-can lock to.
+PyTorch nightly includes support for MPS. Because of this, this setup is
+inherently unstable. One morning I woke up and it no longer worked no matter
+what I did until I switched to miniforge. However, I have another Mac that works
+just fine with Anaconda. If you can't get it to work, please search a little
+first because many of the errors will get posted and solved. If you can't find
+a solution please [create an issue](https://github.com/lstein/stable-diffusion/issues).
 
-To manually update to the latest version of PyTorch nightly (which could fix issues), run this command.
+One debugging step is to update to the latest version of PyTorch nightly.
 
 conda install pytorch torchvision torchaudio -c pytorch-nightly
 
-## Debugging?
+Or you can clean everything up.
 
-Tired of waiting for your renders to finish before you can see if it
-works? Reduce the steps! The picture wont look like anything but if it
-finishes, hey, it works! This could also help you figure out if you've
-got a memory problem, because I'm betting 1 step doesn't use much
-memory.
+conda clean --yes --all
 
-python ./scripts/txt2img.py --prompt "ocean" --ddim_steps 1
+Or you can reset Anaconda.
+
+conda update --force-reinstall -y -n base -c defaults conda
 
 ### "No module named cv2" (or some other module)
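A quick way to confirm that the environment built above actually sees Apple's Metal backend before chasing any of the errors below. This check is an illustrative addition rather than part of the fork's instructions, and it assumes a PyTorch 1.12+ or nightly build:

```python
# sanity check: does this PyTorch build expose the MPS (Metal) backend?
import torch

print(torch.backends.mps.is_built())      # compiled with MPS support?
print(torch.backends.mps.is_available())  # usable on this machine (needs macOS 12.3+)?
```

If both print `True`, the scripts above should be able to run on the GPU rather than falling back to the CPU.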
@@ -83,6 +85,23 @@ globally.
 
 You might also need to install Rust (I mention this again below).
 
+### Debugging?
+
+Tired of waiting for your renders to finish before you can see if it
+works? Reduce the steps! The image quality will be horrible but at least you'll
+get quick feedback.
+
+python ./scripts/txt2img.py --prompt "ocean" --ddim_steps 5 --n_samples 1 --n_iter 1
+
+### MAC: torch._C' has no attribute '_cuda_resetPeakMemoryStats' #234
+
+We haven't gotten dream.py to work on Mac yet.
+
+### OSError: Can't load tokenizer for 'openai/clip-vit-large-patch14'...
+
+python scripts/preload_models.py
+
 ### "The operator [name] is not current implemented for the MPS device." (sic)
 
 Example error.
@@ -92,9 +111,7 @@ Example error.
 
 NotImplementedError: The operator 'aten::index.Tensor' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on [https://github.com/pytorch/pytorch/issues/77764](https://github.com/pytorch/pytorch/issues/77764). As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
 ```
 
-Just do what it says:
-
-export PYTORCH_ENABLE_MPS_FALLBACK=1
+The lstein branch includes this fix in [environment-mac.yaml](https://github.com/lstein/stable-diffusion/blob/main/environment-mac.yaml).
 
 ### "Could not build wheels for tokenizers"
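If you are not using the packaged environment and still hit the error above, the variable it mentions can also be set by hand. A minimal sketch (my addition, not from the README; set it before torch is imported to be safe):

```python
# enable CPU fallback for operators that MPS does not implement yet
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # same effect as `export PYTORCH_ENABLE_MPS_FALLBACK=1`

import torch  # import torch only after the variable is set
```

The warning in the error message still applies: the fallback runs those operators on the CPU, so it will be slower.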
@@ -104,6 +121,8 @@ I have not seen this error because I had Rust installed on my computer before I
 
 ### How come `--seed` doesn't work?
 
+First this:
+
 > Completely reproducible results are not guaranteed across PyTorch
 releases, individual commits, or different platforms. Furthermore,
 results may not be reproducible between CPU and GPU executions, even
@@ -111,7 +130,8 @@ when using identical seeds.
 
 [PyTorch docs](https://pytorch.org/docs/stable/notes/randomness.html)
 
-There is an [open issue](https://github.com/pytorch/pytorch/issues/78035) (as of August 2022) in pytorch regarding gradient inconsistency. I am guessing that's what is causing this.
+Second, we might have a fix that at least gets a consistent seed sort of. We're
+still working on it.
 
 ### libiomp5.dylib error?
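For readers wondering what "a consistent seed" would even involve, the usual PyTorch recipe is sketched below. This is illustrative only (the fork's actual fix may differ), and per the note quoted above it still does not guarantee identical images across CPU, CUDA, and MPS:

```python
# rough sketch of per-run seeding; not the fork's implementation
import random
import numpy as np
import torch

def seed_all(seed: int) -> None:
    random.seed(seed)          # Python's built-in RNG
    np.random.seed(seed)       # NumPy's RNG
    torch.manual_seed(seed)    # torch CPU (and CUDA, if present) generators

seed_all(42)
```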
@@ -137,6 +157,8 @@ sort). [There's more
 suggestions](https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial),
 like uninstalling tensorflow and reinstalling. I haven't tried them.
 
+Since I switched to miniforge I haven't seen the error.
+
 ### Not enough memory.
 
 This seems to be a common problem and is probably the underlying
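One more workaround that appears in that Stack Overflow thread (untested here, and flagged as unsafe by Intel's own error text, so treat it as a last resort) is to let the duplicate OpenMP runtime load anyway:

```python
# hypothetical workaround, not from this README: tolerate a second libiomp5 copy.
# It silences "OMP: Error #15" but can cause crashes or silently wrong results.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
```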
@@ -174,10 +196,10 @@ Actually, this could be happening because there's not enough RAM. You could try
 
 ### My images come out black
 
-I haven't solved this issue. I just throw away my black
-images. There's a [similar
-issue](https://github.com/CompVis/stable-diffusion/issues/69) on CUDA
-GPU's where the images come out green. Maybe it's the same issue?
+We might have this fixed, we are still testing.
+
+There's a [similar issue](https://github.com/CompVis/stable-diffusion/issues/69)
+on CUDA GPU's where the images come out green. Maybe it's the same issue?
 Someone in that issue says to use "--precision full", but this fork
 actually disables that flag. I don't know why, someone else provided
 that code and I don't know what it does. Maybe the `model.half()`
@@ -204,25 +226,4 @@ What? Intel? On an Apple Silicon?
 
 The processor must support the Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) instructions.
 The processor must support the Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
 
-This fixed it for me:
-
-conda clean --yes --all
-
-### Still slow?
-
-I changed the defaults of n_samples and n_iter to 1 so that it uses
-less RAM and makes less images so it will be faster the first time you
-use it. I don't actually know what n_samples does internally, but I
-know it consumes a lot more RAM. The n_iter flag just loops around the
-image creation code, so it shouldn't consume more RAM (it should be
-faster if you're going to do multiple images because the libraries and
-model will already be loaded--use a prompt file to get this speed
-boost).
-
-These flags are the default sample and iter settings in this fork/branch:
-
-~~~~
-python scripts/txt2img.py --prompt "ocean" --n_samples=1 --n_iter=1
-~~~
+This was actually the issue that I couldn't solve until I switched to miniforge.
@@ -605,7 +605,7 @@ This will bring your local copy into sync with the remote one.
 
 ## Macintosh
 
-See (README-Mac-MPS)[README-Mac-MPS.md] for instructions.
+See [README-Mac-MPS](README-Mac-MPS.md) for instructions.
 
 # Simplified API for text to image generation
@@ -52,7 +52,7 @@ model:
       ddconfig:
         double_z: true
         z_channels: 4
-        resolution: 512
+        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
@@ -74,7 +74,7 @@ data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
-    num_workers: 16
+    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
@@ -105,4 +105,5 @@ lightning:
 
  trainer:
    benchmark: True
-    max_steps: 6100
+    max_steps: 4000
@@ -1,3 +1,4 @@
+from math import sqrt, floor, ceil
 from PIL import Image
 
 class InitImageResizer():
@@ -49,6 +50,26 @@ class InitImageResizer():
         new_image = Image.new('RGB',(width,height))
         new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
 
+        print(f'>> Resized image size to {width}x{height}')
+
         return new_image
 
+def make_grid(image_list, rows=None, cols=None):
+    image_cnt = len(image_list)
+    if None in (rows, cols):
+        rows = floor(sqrt(image_cnt))  # try to make it square
+        cols = ceil(image_cnt / rows)
+    width = image_list[0].width
+    height = image_list[0].height
+
+    grid_img = Image.new('RGB', (width * cols, height * rows))
+    i = 0
+    for r in range(0, rows):
+        for c in range(0, cols):
+            if i >= len(image_list):
+                break
+            grid_img.paste(image_list[i], (c * width, r * height))
+            i = i + 1
+
+    return grid_img
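The new module-level `make_grid` helper replaces the grid code that used to live on `PngWriter` (removed in the next hunk). A minimal usage sketch; the file names are invented, and it assumes the images share one size since the grid cell size is taken from the first image:

```python
# illustrative use of the new helper; paths are made up
from PIL import Image
from ldm.dream.image_util import make_grid

images = [Image.open(p) for p in ('000001.42.png', '000002.43.png', '000003.44.png')]
grid = make_grid(images)        # rows/cols default to a roughly square layout
grid.save('grid.png', 'PNG')
```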
@@ -2,95 +2,42 @@
 Two helper classes for dealing with PNG images and their path names.
 PngWriter -- Converts Images generated by T2I into PNGs, finds
              appropriate names for them, and writes prompt metadata
-             into the PNG. Intended to be subclassable in order to
-             create more complex naming schemes, including using the
-             prompt for file/directory names.
+             into the PNG.
 PromptFormatter -- Utility for converting a Namespace of prompt parameters
                    back into a formatted prompt string with command-line switches.
 """
 import os
 import re
-from math import sqrt, floor, ceil
-from PIL import Image, PngImagePlugin
+from PIL import PngImagePlugin
 
 # -------------------image generation utils-----
 
 
 class PngWriter:
-    def __init__(self, outdir, prompt=None):
+    def __init__(self, outdir):
         self.outdir = outdir
-        self.prompt = prompt
-        self.filepath = None
-        self.files_written = []
         os.makedirs(outdir, exist_ok=True)
 
-    def write_image(self, image, seed, upscaled=False):
-        self.filepath = self.unique_filename(
-            seed, upscaled, self.filepath
-        )   # will increment name in some sensible way
-        try:
-            prompt = f'{self.prompt} -S{seed}'
-            self.save_image_and_prompt_to_png(image, prompt, self.filepath)
-        except IOError as e:
-            print(e)
-        if not upscaled:
-            self.files_written.append([self.filepath, seed])
-
-    def unique_filename(self, seed, upscaled=False, previouspath=None):
-        revision = 1
-
-        if previouspath is None:
-            # sort reverse alphabetically until we find max+1
-            dirlist = sorted(os.listdir(self.outdir), reverse=True)
-            # find the first filename that matches our pattern or return 000000.0.png
-            filename = next(
-                (f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
-                '0000000.0.png',
-            )
-            basecount = int(filename.split('.', 1)[0])
-            basecount += 1
-            filename = f'{basecount:06}.{seed}.png'
-            return os.path.join(self.outdir, filename)
-
-        else:
-            basename = os.path.basename(previouspath)
-            x = re.match('^(\d+)\..*\.png', basename)
-            if not x:
-                return self.unique_filename(seed, upscaled, previouspath)
-
-            basecount = int(x.groups()[0])
-            series = 0
-            finished = False
-            while not finished:
-                series += 1
-                filename = f'{basecount:06}.{seed}.png'
-                path = os.path.join(self.outdir, filename)
-                finished = not os.path.exists(path)
-            return os.path.join(self.outdir, filename)
-
-    def save_image_and_prompt_to_png(self, image, prompt, path):
+    # gives the next unique prefix in outdir
+    def unique_prefix(self):
+        # sort reverse alphabetically until we find max+1
+        dirlist = sorted(os.listdir(self.outdir), reverse=True)
+        # find the first filename that matches our pattern or return 000000.0.png
+        existing_name = next(
+            (f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
+            '0000000.0.png',
+        )
+        basecount = int(existing_name.split('.', 1)[0]) + 1
+        return f'{basecount:06}'
+
+    # saves image named _image_ to outdir/name, writing metadata from prompt
+    # returns full path of output
+    def save_image_and_prompt_to_png(self, image, prompt, name):
+        path = os.path.join(self.outdir, name)
         info = PngImagePlugin.PngInfo()
         info.add_text('Dream', prompt)
         image.save(path, 'PNG', pnginfo=info)
-
-    def make_grid(self, image_list, rows=None, cols=None):
-        image_cnt = len(image_list)
-        if None in (rows, cols):
-            rows = floor(sqrt(image_cnt))   # try to make it square
-            cols = ceil(image_cnt / rows)
-        width  = image_list[0].width
-        height = image_list[0].height
-
-        grid_img = Image.new('RGB', (width * cols, height * rows))
-        i = 0
-        for r in range(0, rows):
-            for c in range(0, cols):
-                if i>=len(image_list):
-                    break
-                grid_img.paste(image_list[i], (c * width, r * height))
-                i = i + 1
-
-        return grid_img
+        return path
 
 
 class PromptFormatter:
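After this refactor the writer no longer tracks prompts or invents file names; callers ask for the next prefix and then save each file explicitly. A sketch of the new calling pattern, with an invented prompt and seed and assuming `image` is a PIL image you already generated:

```python
# sketch of the new PngWriter flow; values are illustrative
from ldm.dream.pngwriter import PngWriter

writer = PngWriter('outputs/img-samples/')
prefix = writer.unique_prefix()              # e.g. '000042'
name = f'{prefix}.31337.png'                 # the caller now chooses the file name
path = writer.save_image_and_prompt_to_png(image, 'a forest at dawn -S31337', name)
print(path)                                  # outputs/img-samples/000042.31337.png
```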
@@ -65,6 +65,7 @@ class DreamServer(BaseHTTPRequestHandler):
         post_data = json.loads(self.rfile.read(content_length))
         prompt = post_data['prompt']
         initimg = post_data['initimg']
+        strength = float(post_data['strength'])
         iterations = int(post_data['iterations'])
         steps = int(post_data['steps'])
         width = int(post_data['width'])
@@ -88,24 +89,24 @@ class DreamServer(BaseHTTPRequestHandler):
 
         images_generated = 0    # helps keep track of when upscaling is started
         images_upscaled = 0     # helps keep track of when upscaling is completed
-        pngwriter = PngWriter(
-            "./outputs/img-samples/", config['prompt'], 1
-        )
+        pngwriter = PngWriter("./outputs/img-samples/")
+        prefix = pngwriter.unique_prefix()
 
         # if upscaling is requested, then this will be called twice, once when
         # the images are first generated, and then again when after upscaling
         # is complete. The upscaling replaces the original file, so the second
         # entry should not be inserted into the image list.
         def image_done(image, seed, upscaled=False):
-            pngwriter.write_image(image, seed, upscaled)
+            name = f'{prefix}.{seed}.png'
+            path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
 
             # Append post_data to log, but only once!
             if not upscaled:
-                current_image = pngwriter.files_written[-1]
                 with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
-                    log.write(f"{current_image[0]}: {json.dumps(config)}\n")
+                    log.write(f"{path}: {json.dumps(config)}\n")
 
                 self.wfile.write(bytes(json.dumps(
-                    {'event':'result', 'files':current_image, 'config':config}
+                    {'event': 'result', 'url': path, 'seed': seed, 'config': config}
                 ) + '\n',"utf-8"))
 
         # control state of the "postprocessing..." message
@@ -129,22 +130,24 @@ class DreamServer(BaseHTTPRequestHandler):
                     {'event':action,'processed_file_cnt':f'{x}/{iterations}'}
                 ) + '\n',"utf-8"))
 
-        # TODO: refactor PngWriter:
-        # it doesn't need to know if batch_size > 1, just if this is _part of a batch_
-        step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
+        step_writer = PngWriter('./outputs/intermediates/')
+        step_index = 1
         def image_progress(sample, step):
             if self.canceled.is_set():
                 self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
                 raise CanceledException
-            url = None
+            path = None
             # since rendering images is moderately expensive, only render every 5th image
             # and don't bother with the last one, since it'll render anyway
+            nonlocal step_index
             if progress_images and step % 5 == 0 and step < steps - 1:
                 image = self.model._sample_to_image(sample)
-                step_writer.write_image(image, seed)   # TODO PngWriter to return path
-                url = step_writer.filepath
+                name = f'{prefix}.{seed}.{step_index}.png'
+                metadata = f'{prompt} -S{seed} [intermediate]'
+                path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
+                step_index += 1
             self.wfile.write(bytes(json.dumps(
-                {'event':'step', 'step':step + 1, 'url': url}
+                {'event': 'step', 'step': step + 1, 'url': path}
             ) + '\n',"utf-8"))
 
         try:
@@ -172,6 +175,7 @@ class DreamServer(BaseHTTPRequestHandler):
             # Run img2img
             self.model.prompt2image(prompt,
                                     init_img = "./img2img-tmp.png",
+                                    strength = strength,
                                     iterations = iterations,
                                     cfg_scale = cfgscale,
                                     seed = seed,
ldm/simplet2i.py (205 changed lines)
@ -27,7 +27,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
|||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
from ldm.models.diffusion.ksampler import KSampler
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
from ldm.dream.pngwriter import PngWriter
|
from ldm.dream.pngwriter import PngWriter
|
||||||
from ldm.dream.image_util import InitImageResizer
|
|
||||||
from ldm.dream.devices import choose_torch_device
|
from ldm.dream.devices import choose_torch_device
|
||||||
|
|
||||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
@ -157,7 +156,9 @@ class T2I:
|
|||||||
self.latent_diffusion_weights = latent_diffusion_weights
|
self.latent_diffusion_weights = latent_diffusion_weights
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated()
|
# for VRAM usage statistics
|
||||||
|
self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
|
||||||
|
|
||||||
if seed is None:
|
if seed is None:
|
||||||
self.seed = self._new_seed()
|
self.seed = self._new_seed()
|
||||||
else:
|
else:
|
||||||
@ -171,10 +172,15 @@ class T2I:
|
|||||||
Optional named arguments are the same as those passed to T2I and prompt2image()
|
Optional named arguments are the same as those passed to T2I and prompt2image()
|
||||||
"""
|
"""
|
||||||
results = self.prompt2image(prompt, **kwargs)
|
results = self.prompt2image(prompt, **kwargs)
|
||||||
pngwriter = PngWriter(outdir, prompt)
|
pngwriter = PngWriter(outdir)
|
||||||
for r in results:
|
prefix = pngwriter.unique_prefix()
|
||||||
pngwriter.write_image(r[0], r[1])
|
outputs = []
|
||||||
return pngwriter.files_written
|
for image, seed in results:
|
||||||
|
name = f'{prefix}.{seed}.png'
|
||||||
|
path = pngwriter.save_image_and_prompt_to_png(
|
||||||
|
image, f'{prompt} -S{seed}', name)
|
||||||
|
outputs.append([path, seed])
|
||||||
|
return outputs
|
||||||
|
|
||||||
def txt2img(self, prompt, **kwargs):
|
def txt2img(self, prompt, **kwargs):
|
||||||
outdir = kwargs.pop('outdir', 'outputs/img-samples')
|
outdir = kwargs.pop('outdir', 'outputs/img-samples')
|
||||||
@ -262,16 +268,9 @@ class T2I:
|
|||||||
assert (
|
assert (
|
||||||
0.0 <= strength <= 1.0
|
0.0 <= strength <= 1.0
|
||||||
), 'can only work with strength in [0.0, 1.0]'
|
), 'can only work with strength in [0.0, 1.0]'
|
||||||
w, h = map(
|
|
||||||
lambda x: x - x % 64, (width, height)
|
|
||||||
) # resize to integer multiple of 64
|
|
||||||
|
|
||||||
if h != height or w != width:
|
if not(width == self.width and height == self.height):
|
||||||
print(
|
width, height, _ = self._resolution_check(width, height, log=True)
|
||||||
f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
|
|
||||||
)
|
|
||||||
height = h
|
|
||||||
width = w
|
|
||||||
|
|
||||||
scope = autocast if self.precision == 'autocast' else nullcontext
|
scope = autocast if self.precision == 'autocast' else nullcontext
|
||||||
|
|
||||||
@ -353,7 +352,7 @@ class T2I:
|
|||||||
image_callback(image, seed)
|
image_callback(image, seed)
|
||||||
else:
|
else:
|
||||||
image_callback(image, seed, upscaled=True)
|
image_callback(image, seed, upscaled=True)
|
||||||
else: # no callback passed, so we simply replace old image with rescaled one
|
else: # no callback passed, so we simply replace old image with rescaled one
|
||||||
result[0] = image
|
result[0] = image
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -366,9 +365,6 @@ class T2I:
|
|||||||
print('Are you sure your system has an adequate NVIDIA GPU?')
|
print('Are you sure your system has an adequate NVIDIA GPU?')
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
self.session_peakmem = max(
|
|
||||||
self.session_peakmem, torch.cuda.max_memory_allocated()
|
|
||||||
)
|
|
||||||
print('Usage stats:')
|
print('Usage stats:')
|
||||||
print(
|
print(
|
||||||
f' {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
|
f' {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
|
||||||
@ -377,10 +373,15 @@ class T2I:
|
|||||||
f' Max VRAM used for this generation:',
|
f' Max VRAM used for this generation:',
|
||||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||||
)
|
)
|
||||||
print(
|
|
||||||
f' Max VRAM used since script start: ',
|
if self.session_peakmem:
|
||||||
'%4.2fG' % (self.session_peakmem / 1e9),
|
self.session_peakmem = max(
|
||||||
)
|
self.session_peakmem, torch.cuda.max_memory_allocated()
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f' Max VRAM used since script start: ',
|
||||||
|
'%4.2fG' % (self.session_peakmem / 1e9),
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -435,7 +436,7 @@ class T2I:
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
strength,
|
strength,
|
||||||
callback, # Currently not implemented for img2img
|
callback, # Currently not implemented for img2img
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
An infinite iterator of images from the prompt and the initial image
|
An infinite iterator of images from the prompt and the initial image
|
||||||
@ -444,13 +445,13 @@ class T2I:
|
|||||||
# PLMS sampler not supported yet, so ignore previous sampler
|
# PLMS sampler not supported yet, so ignore previous sampler
|
||||||
if self.sampler_name != 'ddim':
|
if self.sampler_name != 'ddim':
|
||||||
print(
|
print(
|
||||||
f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
|
f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
|
||||||
)
|
)
|
||||||
sampler = DDIMSampler(self.model, device=self.device)
|
sampler = DDIMSampler(self.model, device=self.device)
|
||||||
else:
|
else:
|
||||||
sampler = self.sampler
|
sampler = self.sampler
|
||||||
|
|
||||||
init_image = self._load_img(init_img,width,height).to(self.device)
|
init_image = self._load_img(init_img, width, height).to(self.device)
|
||||||
with precision_scope(self.device.type):
|
with precision_scope(self.device.type):
|
||||||
init_latent = self.model.get_first_stage_encoding(
|
init_latent = self.model.get_first_stage_encoding(
|
||||||
self.model.encode_first_stage(init_image)
|
self.model.encode_first_stage(init_image)
|
||||||
@ -486,22 +487,20 @@ class T2I:
|
|||||||
|
|
||||||
uc = self.model.get_learned_conditioning([''])
|
uc = self.model.get_learned_conditioning([''])
|
||||||
|
|
||||||
# weighted sub-prompts
|
# get weighted sub-prompts
|
||||||
subprompts, weights = T2I._split_weighted_subprompts(prompt)
|
weighted_subprompts = T2I._split_weighted_subprompts(
|
||||||
if len(subprompts) > 1:
|
prompt, skip_normalize)
|
||||||
|
|
||||||
|
if len(weighted_subprompts) > 1:
|
||||||
# i dont know if this is correct.. but it works
|
# i dont know if this is correct.. but it works
|
||||||
c = torch.zeros_like(uc)
|
c = torch.zeros_like(uc)
|
||||||
# get total weight for normalizing
|
|
||||||
totalWeight = sum(weights)
|
|
||||||
# normalize each "sub prompt" and add it
|
# normalize each "sub prompt" and add it
|
||||||
for i in range(0, len(subprompts)):
|
for i in range(0, len(weighted_subprompts)):
|
||||||
weight = weights[i]
|
subprompt, weight = weighted_subprompts[i]
|
||||||
if not skip_normalize:
|
self._log_tokenization(subprompt)
|
||||||
weight = weight / totalWeight
|
|
||||||
self._log_tokenization(subprompts[i])
|
|
||||||
c = torch.add(
|
c = torch.add(
|
||||||
c,
|
c,
|
||||||
self.model.get_learned_conditioning([subprompts[i]]),
|
self.model.get_learned_conditioning([subprompt]),
|
||||||
alpha=weight,
|
alpha=weight,
|
||||||
)
|
)
|
||||||
else: # just standard 1 prompt
|
else: # just standard 1 prompt
|
||||||
@ -513,7 +512,8 @@ class T2I:
|
|||||||
x_samples = self.model.decode_first_stage(samples)
|
x_samples = self.model.decode_first_stage(samples)
|
||||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
if len(x_samples) != 1:
|
if len(x_samples) != 1:
|
||||||
raise Exception(f'expected to get a single image, but got {len(x_samples)}')
|
raise Exception(
|
||||||
|
f'expected to get a single image, but got {len(x_samples)}')
|
||||||
x_sample = 255.0 * rearrange(
|
x_sample = 255.0 * rearrange(
|
||||||
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
||||||
)
|
)
|
||||||
@ -532,7 +532,7 @@ class T2I:
|
|||||||
if self.model is None:
|
if self.model is None:
|
||||||
seed_everything(self.seed)
|
seed_everything(self.seed)
|
||||||
try:
|
try:
|
||||||
config = OmegaConf.load(self.config)
|
config = OmegaConf.load(self.config)
|
||||||
self.device = self._get_device()
|
self.device = self._get_device()
|
||||||
model = self._load_model_from_config(config, self.weights)
|
model = self._load_model_from_config(config, self.weights)
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
@ -544,8 +544,9 @@ class T2I:
|
|||||||
self.model.cond_stage_model.device = self.device
|
self.model.cond_stage_model.device = self.device
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
import traceback
|
import traceback
|
||||||
print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
|
print(
|
||||||
print(traceback.format_exc(),file=sys.stderr)
|
'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
@ -605,65 +606,67 @@ class T2I:
|
|||||||
print(f'image path = {path}, cwd = {os.getcwd()}')
|
print(f'image path = {path}, cwd = {os.getcwd()}')
|
||||||
with Image.open(path) as img:
|
with Image.open(path) as img:
|
||||||
image = img.convert('RGB')
|
image = img.convert('RGB')
|
||||||
print(f'loaded input image of size {image.width}x{image.height} from {path}')
|
print(
|
||||||
|
f'loaded input image of size {image.width}x{image.height} from {path}')
|
||||||
|
|
||||||
image = InitImageResizer(image).resize(width,height)
|
from ldm.dream.image_util import InitImageResizer
|
||||||
print(f'resized input image to size {image.width}x{image.height}')
|
if width == self.width and height == self.height:
|
||||||
|
new_image_width, new_image_height, resize_needed = self._resolution_check(
|
||||||
|
image.width, image.height)
|
||||||
|
else:
|
||||||
|
if height == self.height:
|
||||||
|
new_image_width, new_image_height, resize_needed = self._resolution_check(
|
||||||
|
width, image.height)
|
||||||
|
if width == self.width:
|
||||||
|
new_image_width, new_image_height, resize_needed = self._resolution_check(
|
||||||
|
image.width, height)
|
||||||
|
else:
|
||||||
|
image = InitImageResizer(image).resize(width, height)
|
||||||
|
resize_needed = False
|
||||||
|
if resize_needed:
|
||||||
|
image = InitImageResizer(image).resize(
|
||||||
|
new_image_width, new_image_height)
|
||||||
|
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
return 2.0 * image - 1.0
|
return 2.0 * image - 1.0
|
||||||
|
|
||||||
def _split_weighted_subprompts(text):
|
def _split_weighted_subprompts(text, skip_normalize=False):
|
||||||
"""
|
"""
|
||||||
grabs all text up to the first occurrence of ':'
|
grabs all text up to the first occurrence of ':'
|
||||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||||
if ':' has no value defined, defaults to 1.0
|
if ':' has no value defined, defaults to 1.0
|
||||||
repeats until no text remaining
|
repeats until no text remaining
|
||||||
"""
|
"""
|
||||||
remaining = len(text)
|
prompt_parser = re.compile("""
|
||||||
prompts = []
|
(?P<prompt> # capture group for 'prompt'
|
||||||
weights = []
|
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||||
while remaining > 0:
|
) # end 'prompt'
|
||||||
if ':' in text:
|
(?: # non-capture group
|
||||||
idx = text.index(':') # first occurrence from start
|
:+ # match one or more ':' characters
|
||||||
# grab up to index as sub-prompt
|
(?P<weight> # capture group for 'weight'
|
||||||
prompt = text[:idx]
|
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||||
remaining -= idx
|
)? # end weight capture group, make optional
|
||||||
# remove from main text
|
\s* # strip spaces after weight
|
||||||
text = text[idx + 1 :]
|
| # OR
|
||||||
# find value for weight
|
$ # else, if no ':' then match end of line
|
||||||
if ' ' in text:
|
) # end non-capture group
|
||||||
idx = text.index(' ') # first occurence
|
""", re.VERBOSE)
|
||||||
else: # no space, read to end
|
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||||
idx = len(text)
|
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||||
if idx != 0:
|
if skip_normalize:
|
||||||
try:
|
return parsed_prompts
|
||||||
weight = float(text[:idx])
|
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||||
except: # couldn't treat as float
|
if weight_sum == 0:
|
||||||
print(
|
print(
|
||||||
f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
|
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||||
)
|
equal_weight = 1 / len(parsed_prompts)
|
||||||
weight = 1.0
|
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||||
else: # no value found
|
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||||
weight = 1.0
|
|
||||||
# remove from main text
|
# shows how the prompt is tokenized
|
||||||
remaining -= idx
|
# usually tokens have '</w>' to indicate end-of-word,
|
||||||
text = text[idx + 1 :]
|
|
||||||
# append the sub-prompt and its weight
|
|
||||||
prompts.append(prompt)
|
|
||||||
weights.append(weight)
|
|
||||||
else: # no : found
|
|
||||||
if len(text) > 0: # there is still text though
|
|
||||||
# take remainder as weight 1
|
|
||||||
prompts.append(text)
|
|
||||||
weights.append(1.0)
|
|
||||||
remaining = 0
|
|
||||||
return prompts, weights
|
|
||||||
|
|
||||||
# shows how the prompt is tokenized
|
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
|
||||||
# but for readability it has been replaced with ' '
|
# but for readability it has been replaced with ' '
|
||||||
def _log_tokenization(self, text):
|
def _log_tokenization(self, text):
|
||||||
if not self.log_tokenization:
|
if not self.log_tokenization:
|
||||||
@ -673,15 +676,35 @@ class T2I:
|
|||||||
discarded = ""
|
discarded = ""
|
||||||
usedTokens = 0
|
usedTokens = 0
|
||||||
totalTokens = len(tokens)
|
totalTokens = len(tokens)
|
||||||
for i in range(0,totalTokens):
|
for i in range(0, totalTokens):
|
||||||
token = tokens[i].replace('</w>',' ')
|
token = tokens[i].replace('</w>', ' ')
|
||||||
# alternate color
|
# alternate color
|
||||||
s = (usedTokens % 6) + 1
|
s = (usedTokens % 6) + 1
|
||||||
if i < self.model.cond_stage_model.max_length:
|
if i < self.model.cond_stage_model.max_length:
|
||||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||||
usedTokens += 1
|
usedTokens += 1
|
||||||
else: # over max token length
|
else: # over max token length
|
||||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||||
print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
|
print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||||
if discarded != "":
|
if discarded != "":
|
||||||
print(f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
|
print(
|
||||||
|
f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
|
||||||
|
|
||||||
|
def _resolution_check(self, width, height, log=False):
|
||||||
|
resize_needed = False
|
||||||
|
w, h = map(
|
||||||
|
lambda x: x - x % 64, (width, height)
|
||||||
|
) # resize to integer multiple of 64
|
||||||
|
if h != height or w != width:
|
||||||
|
if log:
|
||||||
|
print(
|
||||||
|
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
|
||||||
|
)
|
||||||
|
height = h
|
||||||
|
width = w
|
||||||
|
resize_needed = True
|
||||||
|
|
||||||
|
if (width * height) > (self.width * self.height):
|
||||||
|
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||||
|
|
||||||
|
return width, height, resize_needed
|
||||||
|
@@ -12,6 +12,7 @@ import time
 import ldm.dream.readline
 from ldm.dream.pngwriter import PngWriter, PromptFormatter
 from ldm.dream.server import DreamServer, ThreadingDreamServer
+from ldm.dream.image_util import make_grid
 
 def main():
     """Initialize command-line parsers and the diffusion model"""
@@ -28,7 +29,10 @@ def main():
     width = 512
     height = 512
     config = 'configs/stable-diffusion/v1-inference.yaml'
-    weights = 'models/ldm/stable-diffusion-v1/model.ckpt'
+    if '.ckpt' in opt.weights:
+        weights = opt.weights
+    else:
+        weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt'
 
     print('* Initializing, be patient...\n')
     sys.path.append('.')
@@ -203,24 +207,40 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
 
         # Here is where the images are actually generated!
         try:
-            file_writer = PngWriter(current_outdir, normalized_prompt)
-            callback = file_writer.write_image if individual_images else None
-            image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
-            results = (
-                file_writer.files_written if individual_images else image_list
-            )
+            file_writer = PngWriter(current_outdir)
+            prefix = file_writer.unique_prefix()
+            seeds = set()
+            results = []
+            grid_images = dict()   # seed -> Image, only used if `do_grid`
+            def image_writer(image, seed, upscaled=False):
+                if do_grid:
+                    grid_images[seed] = image
+                else:
+                    if upscaled and opt.save_original:
+                        filename = f'{prefix}.{seed}.postprocessed.png'
+                    else:
+                        filename = f'{prefix}.{seed}.png'
+                    path = file_writer.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{seed}', filename)
+                    if (not upscaled) or opt.save_original:
+                        # only append to results if we didn't overwrite an earlier output
+                        results.append([path, seed])
 
-            if do_grid and len(results) > 0:
-                grid_img = file_writer.make_grid([r[0] for r in results])
-                filename = file_writer.unique_filename(results[0][1])
-                seeds = [a[1] for a in results]
-                results = [[filename, seeds]]
-                metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
-                file_writer.save_image_and_prompt_to_png(
+                seeds.add(seed)
+
+            t2i.prompt2image(image_callback=image_writer, **vars(opt))
+
+            if do_grid and len(grid_images) > 0:
+                grid_img = make_grid(list(grid_images.values()))
+                first_seed = next(iter(seeds))
+                filename = f'{prefix}.{first_seed}.png'
+                # TODO better metadata for grid images
+                metadata_prompt = f'{normalized_prompt} -S{first_seed}'
+                path = file_writer.save_image_and_prompt_to_png(
                     grid_img, metadata_prompt, filename
                 )
+                results = [[path, seeds]]
 
-            last_seeds = [r[1] for r in results]
+            last_seeds = list(seeds)
 
         except AssertionError as e:
             print(e)
@@ -401,6 +421,11 @@ def create_argv_parser():
         action='store_true',
         help='Start in web server mode.',
     )
+    parser.add_argument(
+        '--weights',
+        default='model',
+        help='Indicates the Stable Diffusion model to use.',
+    )
     return parser
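The new `--weights` option accepts either a bare model name or an explicit `.ckpt` path. A small sketch of the resolution rule the hunk above adds ('custom-model' is a hypothetical name used only for illustration):

```python
# mirrors the weights-resolution logic added to the script
def resolve_weights(weights_arg: str) -> str:
    if '.ckpt' in weights_arg:
        return weights_arg                                        # explicit path is used as-is
    return f'models/ldm/stable-diffusion-v1/{weights_arg}.ckpt'   # bare name is looked up here

print(resolve_weights('model'))          # default -> models/ldm/stable-diffusion-v1/model.ckpt
print(resolve_weights('custom-model'))   # -> models/ldm/stable-diffusion-v1/custom-model.ckpt
```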
@@ -1,4 +1,4 @@
-from ldm.modules.encoders.modules import BERTTokenizer
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
 from ldm.modules.embedding_manager import EmbeddingManager
 
 import argparse, os
@@ -6,7 +6,7 @@ from functools import partial
 
 import torch
 
-def get_placeholder_loop(placeholder_string, tokenizer):
+def get_placeholder_loop(placeholder_string, embedder, use_bert):
 
     new_placeholder = None
 
@@ -16,10 +16,36 @@ def get_placeholder_loop(placeholder_string, tokenizer):
         else:
             new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
 
-        token = tokenizer(new_placeholder)
+        token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
 
-        if torch.count_nonzero(token) == 3:
-            return new_placeholder, token[0, 1]
+        if token is not None:
+            return new_placeholder, token
+
+def get_clip_token_for_string(tokenizer, string):
+    batch_encoding = tokenizer(
+        string,
+        truncation=True,
+        max_length=77,
+        return_length=True,
+        return_overflowing_tokens=False,
+        padding="max_length",
+        return_tensors="pt"
+    )
+
+    tokens = batch_encoding["input_ids"]
+
+    if torch.count_nonzero(tokens - 49407) == 2:
+        return tokens[0, 1]
+
+    return None
+
+def get_bert_token_for_string(tokenizer, string):
+    token = tokenizer(string)
+    if torch.count_nonzero(token) == 3:
+        return token[0, 1]
+
+    return None
 
 if __name__ == "__main__":
@@ -40,10 +66,20 @@ if __name__ == "__main__":
         help="Output path for the merged manager",
     )
 
+    parser.add_argument(
+        "-sd", "--use_bert",
+        action="store_true",
+        help="Flag to denote that we are not merging stable diffusion embeddings"
+    )
+
    args = parser.parse_args()
 
-    tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
-    EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"])
+    if args.use_bert:
+        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
+    else:
+        embedder = FrozenCLIPEmbedder().cuda()
+
+    EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])
 
     string_to_token_dict = {}
     string_to_param_dict = torch.nn.ParameterDict()
@@ -63,7 +99,7 @@ if __name__ == "__main__":
 
             placeholder_to_src[placeholder_string] = manager_ckpt
         else:
-            new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer)
+            new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
             string_to_token_dict[new_placeholder] = new_token
             string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
@@ -77,7 +113,3 @@ if __name__ == "__main__":
 
     print("Managers merged. Final list of placeholders: ")
     print(placeholder_to_src)
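A note on the `49407` test in `get_clip_token_for_string` above: in the CLIP tokenizer, 49407 is the end-of-text id, which also pads the sequence, and 49406 is start-of-text, so exactly two positions differing from 49407 means the placeholder encodes to a single real token. That reading is my own interpretation, not stated in the commit; a hand-written illustration:

```python
# illustrative only: the token ids are written out by hand, not produced by a tokenizer
import torch

# padded CLIP encoding of a one-token placeholder: [BOS, token, EOT, EOT, ...]
tokens = torch.tensor([[49406, 320, 49407, 49407, 49407]])
print(torch.count_nonzero(tokens - 49407))  # tensor(2): BOS plus one real token
print(tokens[0, 1])                          # tensor(320): the id the script keeps
```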
@@ -59,13 +59,14 @@
               <option value="832">832</option> <option value="896">896</option>
               <option value="960">960</option> <option value="1024">1024</option>
             </select>
-            <br>
-            <label title="Upload an image to use img2img" for="initimg">Img2Img Init:</label>
-            <input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
             <label title="Set to -1 for random seed" for="seed">Seed:</label>
             <input value="-1" type="number" id="seed" name="seed">
             <button type="button" id="reset-seed">↺</button>
-            <span>•</span>
+            <br>
+            <label for="strength">Img2Img Strength:</label>
+            <input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
+            <label title="Upload an image to use img2img" for="initimg">Init:</label>
+            <input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
             <button type="button" id="reset-all">Reset to Defaults</button>
             <br>
             <label for="progress_images">Display in-progress images (slows down generation):</label>
@@ -61,8 +61,8 @@ async function generateSubmit(form) {
     let formData = Object.fromEntries(new FormData(form));
     formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
 
-    let strength = 0.75; // TODO let this be specified in the UI
-    let totalSteps = formData.initimg ? Math.floor(.75 * formData.steps) : formData.steps;
+    let strength = formData.strength;
+    let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
 
     let progressSectionEle = document.querySelector('#progress-section');
     progressSectionEle.style.display = 'initial';
@@ -95,10 +95,9 @@ async function generateSubmit(form) {
             if (data.event === 'result') {
                 noOutputs = false;
                 document.querySelector("#no-results-message")?.remove();
-                appendOutput(data.files[0],data.files[1],data.config);
+                appendOutput(data.url, data.seed, data.config);
                 progressEle.setAttribute('value', 0);
                 progressEle.setAttribute('max', totalSteps);
-                progressImageEle.src = BLANK_IMAGE_URL;
             } else if (data.event === 'upscaling-started') {
                 document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
                 document.getElementById("scaling-inprocess-message").style.display = "block";