Merge branch 'development' into development

Commit 387f796ebe by Lincoln Stein, 2022-10-27 23:04:04 -04:00 (committed by GitHub).
61 changed files with 3,549 additions and 497 deletions.

.dockerignore (new file, 3 lines)

@ -0,0 +1,3 @@
*
!environment*.yml
!docker-build

.github/workflows/build-container.yml (new vendored file, 39 lines)

@ -0,0 +1,39 @@
# Building the Image without pushing to confirm it is still buildable
# confirming functionality would unfortunately need far more resources
name: build container image
on:
push:
branches:
- 'main'
- 'development'
pull_request_target:
branches:
- 'main'
- 'development'
jobs:
docker:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Cache Docker layers
uses: actions/cache@v2
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: ${{ runner.os }}-buildx-
- name: Build and push
uses: docker/build-push-action@v3
with:
context: .
file: docker-build/Dockerfile
platforms: linux/amd64
push: false
tags: ${{ github.repository }}:latest
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache
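
For reference, the same check can be reproduced locally with a roughly equivalent Buildx invocation; the tag below is illustrative, not the one the workflow derives from `${{ github.repository }}`:

```bash
# local approximation of the CI job: build the image without pushing it
docker buildx build \
  --platform linux/amd64 \
  --file docker-build/Dockerfile \
  --tag invokeai:ci-test \
  .
```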


@ -14,7 +14,7 @@ from threading import Event
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from ldm.invoke.conditioning import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts
from backend.modules.parameters import parameters_to_command


@ -33,7 +33,7 @@ from ldm.generate import Generate
from ldm.invoke.restoration import Restoration
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.conditioning import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts
from modules.parameters import parameters_to_command


@ -13,6 +13,13 @@ stable-diffusion-1.4:
width: 512
height: 512
default: true
inpainting-1.5:
description: runwayML tuned inpainting model v1.5
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
config: configs/stable-diffusion/v1-inpainting-inference.yaml
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
width: 512
height: 512
stable-diffusion-1.5:
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
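
Once an entry like this exists (and the `.ckpt` weights file is in place), the model can be selected at launch time or switched to from inside the CLI, as documented later in this commit:

```bash
# load the RunwayML inpainting model at startup
python scripts/invoke.py --model inpainting-1.5
# or, from within a running session, type at the invoke> prompt:
#   !switch inpainting-1.5
```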


@ -76,4 +76,4 @@ model:
target: torch.nn.Identity
cond_stage_config:
-target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder


@ -0,0 +1,79 @@
model:
base_learning_rate: 7.5e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid # important
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['face', 'man', 'photo', 'africanmale']
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder


@ -1,57 +1,74 @@
-FROM debian
-ARG gsd
-ENV GITHUB_STABLE_DIFFUSION $gsd
-ARG rsd
-ENV REQS $rsd
-ARG cs
-ENV CONDA_SUBDIR $cs
-ENV PIP_EXISTS_ACTION="w"
-# TODO: Optimize image size
-SHELL ["/bin/bash", "-c"]
-WORKDIR /
-RUN apt update && apt upgrade -y \
-&& apt install -y \
-git \
-libgl1-mesa-glx \
-libglib2.0-0 \
-pip \
-python3 \
-&& git clone $GITHUB_STABLE_DIFFUSION
-# Install Anaconda or Miniconda
-COPY anaconda.sh .
-RUN bash anaconda.sh -b -u -p /anaconda && /anaconda/bin/conda init bash
-# SD
-WORKDIR /stable-diffusion
-RUN source ~/.bashrc \
-&& conda create -y --name ldm && conda activate ldm \
-&& conda config --env --set subdir $CONDA_SUBDIR \
-&& pip3 install -r $REQS \
-&& pip3 install basicsr facexlib realesrgan \
-&& mkdir models/ldm/stable-diffusion-v1 \
-&& ln -s "/data/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt
-# Face restoreation
-# by default expected in a sibling directory to stable-diffusion
-WORKDIR /
-RUN git clone https://github.com/TencentARC/GFPGAN.git
-WORKDIR /GFPGAN
-RUN pip3 install -r requirements.txt \
-&& python3 setup.py develop \
-&& ln -s "/data/GFPGANv1.4.pth" experiments/pretrained_models/GFPGANv1.4.pth
-WORKDIR /stable-diffusion
-RUN python3 scripts/preload_models.py
-WORKDIR /
-COPY entrypoint.sh .
-ENTRYPOINT ["/entrypoint.sh"]
FROM ubuntu AS get_miniconda
SHELL ["/bin/bash", "-c"]
# install wget
RUN apt-get update \
&& apt-get install -y \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# download and install miniconda
ARG conda_version=py39_4.12.0-Linux-x86_64
ARG conda_prefix=/opt/conda
RUN wget --progress=dot:giga -O /miniconda.sh \
https://repo.anaconda.com/miniconda/Miniconda3-${conda_version}.sh \
&& bash /miniconda.sh -b -p ${conda_prefix} \
&& rm -f /miniconda.sh
FROM ubuntu AS invokeai
# use bash
SHELL [ "/bin/bash", "-c" ]
# clean bashrc
RUN echo "" > ~/.bashrc
# Install necessary packages
RUN apt-get update \
&& apt-get install -y \
--no-install-recommends \
gcc \
git \
libgl1-mesa-glx \
libglib2.0-0 \
pip \
python3 \
python3-dev \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# clone repository and create symlinks
ARG invokeai_git=https://github.com/invoke-ai/InvokeAI.git
ARG project_name=invokeai
RUN git clone ${invokeai_git} /${project_name} \
&& mkdir /${project_name}/models/ldm/stable-diffusion-v1 \
&& ln -s /data/models/sd-v1-4.ckpt /${project_name}/models/ldm/stable-diffusion-v1/model.ckpt \
&& ln -s /data/outputs/ /${project_name}/outputs
# set workdir
WORKDIR /${project_name}
# install conda env and preload models
ARG conda_prefix=/opt/conda
ARG conda_env_file=environment.yml
COPY --from=get_miniconda ${conda_prefix} ${conda_prefix}
RUN source ${conda_prefix}/etc/profile.d/conda.sh \
&& conda init bash \
&& source ~/.bashrc \
&& conda env create \
--name ${project_name} \
--file ${conda_env_file} \
&& rm -Rf ~/.cache \
&& conda clean -afy \
&& echo "conda activate ${project_name}" >> ~/.bashrc \
&& ln -s /data/models/GFPGANv1.4.pth ./src/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth \
&& conda activate ${project_name} \
&& python scripts/preload_models.py
# Copy entrypoint and set env
ENV CONDA_PREFIX=${conda_prefix}
ENV PROJECT_NAME=${project_name}
COPY docker-build/entrypoint.sh /
ENTRYPOINT [ "/entrypoint.sh" ]

docker-build/build.sh (new executable file, 81 lines)

@ -0,0 +1,81 @@
#!/usr/bin/env bash
set -e
# IMPORTANT: You need to have a token on huggingface.co to be able to download the checkpoint!!!
# configure values by using env when executing build.sh
# e.g. env ARCH=aarch64 INVOKEAI_GIT=https://github.com/yourname/yourfork.git ./docker-build/build.sh
source ./docker-build/env.sh || { echo "please run from repository root"; exit 1; }
invokeai_conda_version=${INVOKEAI_CONDA_VERSION:-py39_4.12.0-${platform/\//-}}
invokeai_conda_prefix=${INVOKEAI_CONDA_PREFIX:-\/opt\/conda}
invokeai_conda_env_file=${INVOKEAI_CONDA_ENV_FILE:-environment.yml}
invokeai_git=${INVOKEAI_GIT:-https://github.com/invoke-ai/InvokeAI.git}
huggingface_token=${HUGGINGFACE_TOKEN?}
# print the settings
echo "You are using these values:"
echo -e "project_name:\t\t ${project_name}"
echo -e "volumename:\t\t ${volumename}"
echo -e "arch:\t\t\t ${arch}"
echo -e "platform:\t\t ${platform}"
echo -e "invokeai_conda_version:\t ${invokeai_conda_version}"
echo -e "invokeai_conda_prefix:\t ${invokeai_conda_prefix}"
echo -e "invokeai_conda_env_file: ${invokeai_conda_env_file}"
echo -e "invokeai_git:\t\t ${invokeai_git}"
echo -e "invokeai_tag:\t\t ${invokeai_tag}\n"
_runAlpine() {
docker run \
--rm \
--interactive \
--tty \
--mount source="$volumename",target=/data \
--workdir /data \
alpine "$@"
}
_copyCheckpoints() {
echo "creating subfolders for models and outputs"
_runAlpine mkdir models
_runAlpine mkdir outputs
echo -n "downloading sd-v1-4.ckpt"
_runAlpine wget --header="Authorization: Bearer ${huggingface_token}" -O models/sd-v1-4.ckpt https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
echo "done"
echo "downloading GFPGANv1.4.pth"
_runAlpine wget -O models/GFPGANv1.4.pth https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth
}
_checkVolumeContent() {
_runAlpine ls -lhA /data/models
}
_getModelMd5s() {
_runAlpine sh -c "md5sum /data/models/*"
}
if [[ -n "$(docker volume ls -f name="${volumename}" -q)" ]]; then
echo "Volume already exists"
if [[ -z "$(_checkVolumeContent)" ]]; then
echo "looks empty, copying checkpoint"
_copyCheckpoints
fi
echo "Models in ${volumename}:"
_checkVolumeContent
else
echo -n "createing docker volume "
docker volume create "${volumename}"
_copyCheckpoints
fi
# Build Container
docker build \
--platform="${platform}" \
--tag "${invokeai_tag}" \
--build-arg project_name="${project_name}" \
--build-arg conda_version="${invokeai_conda_version}" \
--build-arg conda_prefix="${invokeai_conda_prefix}" \
--build-arg conda_env_file="${invokeai_conda_env_file}" \
--build-arg invokeai_git="${invokeai_git}" \
--file ./docker-build/Dockerfile \
.
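
A typical invocation, run from the repository root with only the required variable set (the token value is a placeholder):

```bash
# HUGGINGFACE_TOKEN is required so the script can download sd-v1-4.ckpt
env HUGGINGFACE_TOKEN=<your-huggingface-token> ./docker-build/build.sh
```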


@ -1,10 +1,8 @@
#!/bin/bash
set -e
-cd /stable-diffusion
source "${CONDA_PREFIX}/etc/profile.d/conda.sh"
conda activate "${PROJECT_NAME}"
-if [ $# -eq 0 ]; then
-python3 scripts/dream.py --full_precision -o /data
-# bash
-else
-python3 scripts/dream.py --full_precision -o /data "$@"
-fi
python scripts/invoke.py \
${@:---web --host=0.0.0.0}

docker-build/env.sh (new file, 13 lines)

@ -0,0 +1,13 @@
#!/usr/bin/env bash
project_name=${PROJECT_NAME:-invokeai}
volumename=${VOLUMENAME:-${project_name}_data}
arch=${ARCH:-x86_64}
platform=${PLATFORM:-Linux/${arch}}
invokeai_tag=${INVOKEAI_TAG:-${project_name}-${arch}}
export project_name
export volumename
export arch
export platform
export invokeai_tag

docker-build/run.sh (new executable file, 15 lines)

@ -0,0 +1,15 @@
#!/usr/bin/env bash
set -e
source ./docker-build/env.sh || { echo "please run from repository root"; exit 1; }
docker run \
--interactive \
--tty \
--rm \
--platform "$platform" \
--name "$project_name" \
--hostname "$project_name" \
--mount source="$volumename",target=/data \
--publish 9090:9090 \
"$invokeai_tag" ${1:+$@}

Seventeen binary image files added (487 KiB to 598 KiB each); image contents not shown.


@ -153,6 +153,7 @@ Here are the invoke> command that apply to txt2img:
| --cfg_scale <float>| -C<float> | 7.5 | How hard to try to match the prompt to the generated image; any number greater than 1.0 works, but the useful range is roughly 5.0 to 20.0 |
| --seed <int> | -S<int> | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.|
| --sampler <sampler>| -A<sampler>| k_lms | Sampler to use. Use -h to get list of available samplers. |
| --karras_max <int> | | 29 | When using k_* samplers, set the maximum number of steps before shifting from the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). This value is sticky. [29] |
| --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution |
| --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
| --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt |
@ -218,8 +219,13 @@ well as the --mask (-M) and --text_mask (-tm) arguments:
| Argument <img width="100" align="right"/> | Shortcut | Default | Description |
|--------------------|------------|---------------------|--------------|
| `--init_mask <path>` | `-M<path>` | `None` | Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
| `--invert_mask` | | `False` | If true, invert the mask so that transparent areas are opaque and vice versa.|
| `--text_mask <prompt> [<float>]` | `-tm <prompt> [<float>]` | <none> | Create a mask from a text prompt describing part of the image|

The mask may either be an image with transparent areas, in which case
the inpainting will occur in the transparent areas only, or a black
and white image, in which case all black areas will be painted into.
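
For instance, a hypothetical invocation that supplies a hand-made mask and then inverts it, so that everything except the masked area is repainted (file names and prompt are placeholders):

```bash
invoke> "a vase of flowers on a wooden table" -I ./photos/table.png -M ./photos/table-mask.png --invert_mask
```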
`--text_mask` (short form `-tm`) is a way to generate a mask using a
text description of the part of the image to replace. For example, if
you have an image of a breakfast plate with a bagel, toast and


@ -121,8 +121,6 @@ Both of the outputs look kind of like what I was thinking of. With the strength
If you want to try this out yourself, all of these are using a seed of `1592514025` with a width/height of `384`, step count `10`, the default sampler (`k_lms`), and the single-word prompt `"fire"`:
-If you want to try this out yourself, all of these are using a seed of `1592514025` with a width/height of `384`, step count `10`, the default sampler (`k_lms`), and the single-word prompt `fire`:
```commandline
invoke> "fire" -s10 -W384 -H384 -S1592514025 -I /tmp/fire-drawing.png --strength 0.7
```


@ -34,6 +34,16 @@ original unedited image and the masked (partially transparent) image:
invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png
``` ```
If you are using Photoshop to make your transparent masks, here is a
protocol contributed by III_Communication36 (Discord name):
Create your alpha channel for the mask in Photoshop, then run
Image > Adjustments > Threshold on that channel. Export it with "Save a
Copy" using the SuperPNG plugin (a free third-party download), making
sure the alpha channel is selected. Masking then works as it should for
the img2img process. You can feed just one image this way, without
needing to supply the `-M` mask alongside it.
## **Masking using Text**
You can also create a mask using a text prompt to select the part of
@ -139,7 +149,83 @@ region directly:
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20
```
-### Inpainting is not changing the masked region enough!
## Using the RunwayML inpainting model
The [RunwayML Inpainting Model
v1.5](https://huggingface.co/runwayml/stable-diffusion-inpainting) is
a specialized version of [Stable Diffusion
v1.5](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5)
that contains extra channels specifically designed to enhance
inpainting and outpainting. While it can do regular `txt2img` and
`img2img`, it really shines when filling in missing regions. It has an
almost uncanny ability to blend the new regions with existing ones in
a semantically coherent way.
To install the inpainting model, follow the
[instructions](INSTALLING-MODELS.md) for installing a new model. You
may use either the CLI (`invoke.py` script) or directly edit the
`configs/models.yaml` configuration file to do this. The main thing to
watch out for is that the model `config` option must be set up to
use `v1-inpainting-inference.yaml` rather than the `v1-inference.yaml`
file that is used by Stable Diffusion 1.4 and 1.5.
After installation, your `models.yaml` should contain an entry that
looks like this one:
inpainting-1.5:
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
description: SD inpainting v1.5
config: configs/stable-diffusion/v1-inpainting-inference.yaml
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
width: 512
height: 512
As shown in the example, you may include a VAE fine-tuning weights
file as well. This is strongly recommended.
To use the custom inpainting model, launch `invoke.py` with the
argument `--model inpainting-1.5` or alternatively from within the
script use the `!switch inpainting-1.5` command to load and switch to
the inpainting model.
You can now do inpainting and outpainting exactly as described above,
but there will (likely) be a noticeable improvement in
coherence. Txt2img and Img2img will work as well.
There are a few caveats to be aware of:
1. The inpainting model is larger than the standard model, and will
use nearly 4 GB of GPU VRAM. This makes it unlikely to run on
a 4 GB graphics card.
2. When operating in Img2img mode, the inpainting model is much less
steerable than the standard model. It is great for making small
changes, such as changing the pattern of a fabric, or slightly
changing a subject's expression or hair, but the model will
resist making the dramatic alterations that the standard
model lets you do.
3. While the `--hires` option works fine with the inpainting model,
some special features, such as `--embiggen` are disabled.
4. Prompt weighting (`banana++ sushi`) and merging work well with
the inpainting model, but prompt swapping (`a ("fluffy cat").swap("smiling dog") eating a hotdog`)
will not have any effect due to the way the model is set up.
You may use text masking (with `-tm thing-to-mask`) as an
effective replacement.
5. The model tends to oversharpen images if you use high step or CFG
values. If you need to do large steps, use the standard model.
6. The `--strength` (`-f`) option has no effect on the inpainting
model due to its fundamental differences with the standard
model. It will always take the full number of steps you specify.
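
Putting this together, a minimal session with the new model might look like the following sketch (the image path and prompts are illustrative):

```bash
# launch with the inpainting model loaded
python scripts/invoke.py --model inpainting-1.5
# then, at the invoke> prompt, repaint the region selected by a text mask:
#   invoke> "a pine forest at dawn" -I ./photos/clearing.png -tm sky
```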
## Troubleshooting
Here are some troubleshooting tips for inpainting and outpainting.
## Inpainting is not changing the masked region enough!
One of the things to understand about how inpainting works is that it
is equivalent to running img2img on just the masked (transparent)


@ -15,13 +15,52 @@ InvokeAI supports two versions of outpainting, one called "outpaint"
and the other "outcrop." They work slightly differently and each has
its advantages and drawbacks.
### Outpainting
Outpainting is the same as inpainting, except that the painting occurs
in the regions outside of the original image. To outpaint using the
`invoke.py` command line script, prepare an image in which the borders
to be extended are pure black. Add an alpha channel (if there isn't one
already), and make the borders completely transparent and the interior
completely opaque. If you wish to modify the interior as well, you may
create transparent holes in the transparency layer, which `img2img` will
paint into as usual.
Pass the image as the argument to the `-I` switch as you would for
regular inpainting:
invoke> a stream by a river -I /path/to/transparent_img.png
You'll likely be delighted by the results.
### Tips
1. Do not try to expand the image too much at once. Generally it is best
to expand the margins in 64-pixel increments. 128 pixels often works,
but your mileage may vary depending on the nature of the image you are
trying to outpaint into.
2. There are a series of switches that can be used to adjust how the
inpainting algorithm operates. In particular, you can use these to
minimize the seam that sometimes appears between the original image
and the extended part. These switches are:
--seam_size SEAM_SIZE Size of the mask around the seam between original and outpainted image (0)
--seam_blur SEAM_BLUR The amount to blur the seam inwards (0)
--seam_strength STRENGTH The img2img strength to use when filling the seam (0.7)
--seam_steps SEAM_STEPS The number of steps to use to fill the seam. (10)
--tile_size TILE_SIZE The tile size to use for filling outpaint areas (32)
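
For example, an illustrative command that outpaints into the transparent borders while smoothing the seam (the file name and values are placeholders, not recommendations):

```bash
invoke> a stream by a river -I /path/to/transparent_img.png --seam_size 16 --seam_blur 8 --seam_steps 20
```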
### Outcrop
-The `outcrop` extension allows you to extend the image in 64 pixel
-increments in any dimension. You can apply the module to any
-previously-generated by InvokeAI. Note that it will **not** work with
-arbitrary photographs or Stable Diffusion images created by other
-implementations.
The `outcrop` extension gives you a convenient `!fix` postprocessing
command that allows you to extend a previously-generated image in 64
pixel increments in any direction. You can apply the module to any
image previously-generated by InvokeAI. Note that it works with
arbitrary PNG photographs, but not currently with JPG or other
formats. Outcropping is particularly effective when combined with the
[runwayML custom inpainting
model](INPAINTING.md#using-the-runwayml-inpainting-model).
Consider this image:
@ -64,42 +103,3 @@ you'll get a slightly different result. You can run it repeatedly
until you get an image you like. Unfortunately `!fix` does not
currently respect the `-n` (`--iterations`) argument.
-## Outpaint
-The `outpaint` extension does the same thing, but with subtle
-differences. Starting with the same image, here is how we would add an
-additional 64 pixels to the top of the image:
-```bash
-invoke> !fix images/curly.png --out_direction top 64
-```
-(you can abbreviate `--out_direction` as `-D`.
-The result is shown here:
-<div align="center" markdown>
-![curly_woman_outpaint](../assets/outpainting/curly-outpaint.png)
-</div>
-Although the effect is similar, there are significant differences from
-outcropping:
-- You can only specify one direction to extend at a time.
-- The image is **not** resized. Instead, the image is shifted by the specified
-number of pixels. If you look carefully, you'll see that less of the lady's
-torso is visible in the image.
-- Because the image dimensions remain the same, there's no rounding
-to multiples of 64.
-- Attempting to outpaint larger areas will frequently give rise to ugly
-ghosting effects.
-- For best results, try increasing the step number.
-- If you don't specify a pixel value in `-D`, it will default to half
-of the whole image, which is likely not what you want.
-!!! tip
-Neither `outpaint` nor `outcrop` are perfect, but we continue to tune
-and improve them. If one doesn't work, try the other. You may also
-wish to experiment with other `img2img` arguments, such as `-C`, `-f`
-and `-s`.


@ -45,7 +45,7 @@ Here's a prompt that depicts what it does.
original prompt:
-`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
`#!bash "A fantastical translucent pony made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
<div align="center" markdown>
![step1](../assets/negative_prompt_walkthru/step1.png)
@ -84,6 +84,109 @@ Getting close - but there's no sense in having a saddle when our horse doesn't h
---
## **Prompt Syntax Features**
The InvokeAI prompting language has the following features:
### Attention weighting
Append a word or phrase with `-` or `+`, or a weight between `0` and `2` (`1`=default), to decrease or increase "attention" (= a mix of per-token CFG weighting multiplier and, for `-`, a weighted blend with the prompt without the term).
The following syntax is recognised:
* single words without parentheses: `a tall thin man picking apricots+`
* single or multiple words with parentheses: `a tall thin man picking (apricots)+` `a tall thin man picking (apricots)-` `a tall thin man (picking apricots)+` `a tall thin man (picking apricots)-`
* more effect with more symbols `a tall thin man (picking apricots)++`
* nesting `a tall thin man (picking apricots+)++` (`apricots` effectively gets `+++`)
* all of the above with explicit numbers `a tall thin man picking (apricots)1.1` `a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++` is pow(1.1,2), `+++` is pow(1.1,3), etc; `-` means 0.9, `--` means pow(0.9,2), etc.)
* attention also applies to `[unconditioning]` so `a tall thin man picking apricots [(ladder)0.01]` will *very gently* nudge SD away from trying to draw the man on a ladder
You can use this to increase or decrease the amount of something. Starting from this prompt of `a man picking apricots from a tree`, let's see what happens if we increase and decrease how much attention we want Stable Diffusion to pay to the word `apricots`:
![an AI generated image of a man picking apricots from a tree](../assets/prompt_syntax/apricots-0.png)
Using `-` to reduce apricot-ness:
| `a man picking apricots- from a tree` | `a man picking apricots-- from a tree` | `a man picking apricots--- from a tree` |
| -- | -- | -- |
| ![an AI generated image of a man picking apricots from a tree, with smaller apricots](../assets/prompt_syntax/apricots--1.png) | ![an AI generated image of a man picking apricots from a tree, with even smaller and fewer apricots](../assets/prompt_syntax/apricots--2.png) | ![an AI generated image of a man picking apricots from a tree, with very few very small apricots](../assets/prompt_syntax/apricots--3.png) |
Using `+` to increase apricot-ness:
| `a man picking apricots+ from a tree` | `a man picking apricots++ from a tree` | `a man picking apricots+++ from a tree` | `a man picking apricots++++ from a tree` | `a man picking apricots+++++ from a tree` |
| -- | -- | -- | -- | -- |
| ![an AI generated image of a man picking apricots from a tree, with larger, more vibrant apricots](../assets/prompt_syntax/apricots-1.png) | ![an AI generated image of a man picking apricots from a tree with even larger, even more vibrant apricots](../assets/prompt_syntax/apricots-2.png) | ![an AI generated image of a man picking apricots from a tree, but the man has been replaced by a pile of apricots](../assets/prompt_syntax/apricots-3.png) | ![an AI generated image of a man picking apricots from a tree, but the man has been replaced by a mound of giant melting-looking apricots](../assets/prompt_syntax/apricots-4.png) | ![an AI generated image of a man picking apricots from a tree, but the man and the leaves and parts of the ground have all been replaced by giant melting-looking apricots](../assets/prompt_syntax/apricots-5.png) |
You can also change the balance between different parts of a prompt. For example, below is a `mountain man`:
![an AI generated image of a mountain man](../assets/prompt_syntax/mountain-man.png)
And here he is with more mountain:
| `mountain+ man` | `mountain++ man` | `mountain+++ man` |
| -- | -- | -- |
| ![](../assets/prompt_syntax/mountain1-man.png) | ![](../assets/prompt_syntax/mountain2-man.png) | ![](../assets/prompt_syntax/mountain3-man.png) |
Or, alternatively, with more man:
| `mountain man+` | `mountain man++` | `mountain man+++` | `mountain man++++` |
| -- | -- | -- | -- |
| ![](../assets/prompt_syntax/mountain-man1.png) | ![](../assets/prompt_syntax/mountain-man2.png) | ![](../assets/prompt_syntax/mountain-man3.png) | ![](../assets/prompt_syntax/mountain-man4.png) |
### Blending between prompts
* `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)`
* The existing prompt blending using `:<weight>` will continue to be supported - `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` is equivalent to `a tall thin man picking apricots:1 a tall thin man picking pears:1` in the old syntax.
* Attention weights can be nested inside blends.
* Non-normalized blends are supported by passing `no_normalize` as an additional argument to the blend weights, eg `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,-1,no_normalize)`. very fun to explore local maxima in the feature space, but also easy to produce garbage output.
See the section below on "Prompt Blending" for more information about how this works.
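
For instance, a blend can be combined with ordinary generation options on the command line (weights and options are illustrative):

```bash
invoke> ("a tall thin man picking apricots", "a tall thin man picking pears").blend(0.7, 0.3) -s 30 -W 512 -H 512
```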
### Cross-Attention Control ('prompt2prompt')
Sometimes an image you generate is almost right, and you just want to
change one detail without affecting the rest. You could use a photo editor and inpainting
to overpaint the area, but that's a pain. Here's where `prompt2prompt`
comes in handy.
Generate an image with a given prompt, record the seed of the image,
and then use the `prompt2prompt` syntax to substitute words in the
original prompt for words in a new prompt. This works for `img2img` as well.
* `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
* quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
* for single word substitutions parentheses are also optional: `a cat.swap(dog) eating a hotdog`.
* Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_spatial_start/_end` and `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to intuitively understand.
* Example usage:`a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` argument means that the "spatial" (self-attention) edit will stop having any effect after 30% (=0.3) of the steps have been done, leaving Stable Diffusion with 70% of the steps where it is free to decide for itself how to reshape the cat-form into a dog form.
* The numbers represent a percentage through the step sequence where the edits should happen. 0 means the start (noisy starting image), 1 is the end (final image).
* For img2img, the step sequence does not start at 0 but instead at (1-strength) - so if strength is 0.7, s_start and s_end must both be greater than 0.3 (1-0.7) to have any effect.
* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
* `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
The `prompt2prompt` code is based off [bloc97's
colab](https://github.com/bloc97/CrossAttentionControl).
Note that `prompt2prompt` is not currently working with the runwayML
inpainting model, and may never work due to the way this model is set
up. If you attempt to use `prompt2prompt` you will get the original
image back. However, since this model is so good at inpainting, a
good substitute is to use the `clipseg` text masking option:
```
invoke> a fluffy cat eating a hotdog
Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
```
### Escaping parentheses () and speech marks ""
If the model you are using has parentheses () or speech marks "" as
part of its syntax, you will need to "escape" these using a backslash,
so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt
parser will attempt to interpret the parentheses as part of the prompt
syntax and it will get confused.
## **Prompt Blending**
You may blend together different sections of the prompt to explore the


@ -36,20 +36,6 @@ another environment with NVIDIA GPUs on-premises or in the cloud.
### Prerequisites
-#### Get the data files
-Go to
-[Hugging Face](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original),
-and click "Access repository" to Download the model file `sd-v1-4.ckpt` (~4 GB)
-to `~/Downloads`. You'll need to create an account but it's quick and free.
-Also download the face restoration model.
-```Shell
-cd ~/Downloads
-wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth
-```
#### Install [Docker](https://github.com/santisbon/guides#docker)
On the Docker Desktop app, go to Preferences, Resources, Advanced. Increase the
@ -57,86 +43,61 @@ CPUs and Memory to avoid this
[Issue](https://github.com/invoke-ai/InvokeAI/issues/342). You may need to
increase Swap and Disk image size too.
#### Get a Huggingface-Token
Go to [Hugging Face](https://huggingface.co/settings/tokens), create a token and
temporarily place it somewhere like an open text editor window (but don't save it!
Just keep the window open; we need the token in the next step).
### Setup
Set the fork you want to use and other variables.
-```Shell
-TAG_STABLE_DIFFUSION="santisbon/stable-diffusion"
-PLATFORM="linux/arm64"
-GITHUB_STABLE_DIFFUSION="-b orig-gfpgan https://github.com/santisbon/stable-diffusion.git"
-REQS_STABLE_DIFFUSION="requirements-linux-arm64.txt"
-CONDA_SUBDIR="osx-arm64"
-echo $TAG_STABLE_DIFFUSION
-echo $PLATFORM
-echo $GITHUB_STABLE_DIFFUSION
-echo $REQS_STABLE_DIFFUSION
-echo $CONDA_SUBDIR
-```
!!! tip
    I prefer to save my env vars in a `.env` (or `.envrc`) file in the repository
    root, so that they are automatically re-applied when I come back.
The build- and run-scripts contain default values for almost everything besides
the [Hugging Face Token](https://huggingface.co/settings/tokens) you created in
the last step.
Some suggestions of variables you may want to change besides the token:
| Environment-Variable | Description |
| -------------------- | ----------- |
| `HUGGINGFACE_TOKEN="hg_aewirhghlawrgkjbarug2"` | This is the only required variable; without it you can't download the checkpoint |
| `ARCH=aarch64` | if you are using an ARM-based CPU |
| `INVOKEAI_TAG=yourname/invokeai:latest` | the container repository / tag which will be used |
| `INVOKEAI_CONDA_ENV_FILE=environment-linux-aarch64.yml` | since `environment.yml` wouldn't work with aarch64 |
| `INVOKEAI_GIT="-b branchname https://github.com/username/reponame"` | if you want to use your own fork |
#### Build the Image

I provided a build script, which is located in `docker-build/build.sh`, but it
still needs to be executed from the repository root.

```bash
docker-build/build.sh
```

The build script not only builds the container, it also creates the docker
volume if it does not exist yet, and downloads the models if the volume is
empty. When it is done you can run the container via the run script:

```bash
docker-build/run.sh
```

When used without arguments, the container will start the website and provide
you the link to open it. But if you want to use some other parameters you can
also do so.

!!! warning "Deprecated"

    From here on is the rest of the previous Docker docs, which may still
    provide useful information for one case or another.

-Create a Docker volume for the downloaded model files.
-```Shell
-docker volume create my-vol
-```
-Copy the data files to the Docker volume using a lightweight Linux container.
-We'll need the models at run time. You just need to create the container with
-the mountpoint; no need to run this dummy container.
-```Shell
-cd ~/Downloads # or wherever you saved the files
-docker create --platform $PLATFORM --name dummy --mount source=my-vol,target=/data alpine
-docker cp sd-v1-4.ckpt dummy:/data
-docker cp GFPGANv1.4.pth dummy:/data
-```
-Get the repo and download the Miniconda installer (we'll need it at build time).
-Replace the URL with the version matching your container OS and the architecture
-it will run on.
-```Shell
-cd ~
-git clone $GITHUB_STABLE_DIFFUSION
-cd stable-diffusion/docker-build
-chmod +x entrypoint.sh
-wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh -O anaconda.sh && chmod +x anaconda.sh
-```
-Build the Docker image. Give it any tag `-t` that you want.
-Choose the Linux container's host platform: x86-64/Intel is `amd64`. Apple
-silicon is `arm64`. If deploying the container to the cloud to leverage powerful
-GPU instances you'll be on amd64 hardware but if you're just trying this out
-locally on Apple silicon choose arm64.
-The application uses libraries that need to match the host environment so use
-the appropriate requirements file.
-Tip: Check that your shell session has the env variables set above.
-```Shell
-docker build -t $TAG_STABLE_DIFFUSION \
---platform $PLATFORM \
---build-arg gsd=$GITHUB_STABLE_DIFFUSION \
---build-arg rsd=$REQS_STABLE_DIFFUSION \
---build-arg cs=$CONDA_SUBDIR \
-.
-```
-Run a container using your built image.
-Tip: Make sure you've created and populated the Docker volume (above).
-```Shell
-docker run -it \
---rm \
---platform $PLATFORM \
---name stable-diffusion \
---hostname stable-diffusion \
---mount source=my-vol,target=/data \
-$TAG_STABLE_DIFFUSION
-```

## Usage (time to have fun)
@ -240,7 +201,8 @@ server with:
python3 scripts/invoke.py --full_precision --web
```
-If it's running on your Mac point your Mac web browser to http://127.0.0.1:9090
If it's running on your Mac point your Mac web browser to
<http://127.0.0.1:9090>
Press Control-C at the command line to stop the web server.


@ -0,0 +1,44 @@
name: invokeai
channels:
- pytorch
- conda-forge
dependencies:
- python>=3.9
- pip>=20.3
- cudatoolkit
- pytorch
- torchvision
- numpy=1.19
- imageio=2.9.0
- opencv=4.6.0
- pillow=8.*
- flask=2.1.*
- flask_cors=3.0.10
- flask-socketio=5.3.0
- send2trash=1.8.0
- eventlet
- albumentations=0.4.3
- pudb=2019.2
- imageio-ffmpeg=0.4.2
- pytorch-lightning=1.7.7
- streamlit
- einops=0.3.0
- kornia=0.6
- torchmetrics=0.7.0
- transformers=4.21.3
- torch-fidelity=0.3.0
- tokenizers>=0.11.1,!=0.11.3,<0.13
- pip:
- omegaconf==2.1.1
- realesrgan==0.2.5.0
- test-tube>=0.7.5
- pyreadline3
- dependency_injector==4.40.0
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
- -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
- -e .
variables:
PYTORCH_ENABLE_MPS_FALLBACK: 1
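
A minimal sketch of how an environment file like this is typically used; the file name below is a placeholder for whatever this file is called in the repository:

```bash
# create and activate the conda environment defined above, then preload models
conda env create -f environment-mac.yml   # file name assumed
conda activate invokeai
python scripts/preload_models.py
```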


@ -1,5 +1,5 @@
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
import pyparsing
# Derived from source code carrying the following copyrights
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
@ -24,6 +24,7 @@ from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
from ldm.invoke.prompt_parser import PromptParser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.args import metadata_from_png
from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision
-from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
@ -179,6 +180,7 @@ class Generate:
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
self.karras_max = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -269,10 +271,12 @@ class Generate:
variation_amount = 0.0,
threshold = 0.0,
perlin = 0.0,
karras_max = None,
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
text_mask = None,
invert_mask = False,
fit = False,
strength = None,
init_color = None,
@ -293,6 +297,13 @@ class Generate:
catch_interrupts = False,
hires_fix = False,
use_mps_noise = False,
# Seam settings for outpainting
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
force_outpaint: bool = False,
**args,
): # eat up additional cruft
"""
@ -310,6 +321,7 @@ class Generate:
init_img // path to an initial image
init_mask // path to a mask for the initial image
text_mask // a text string that will be used to guide clipseg generation of the init_mask
invert_mask // boolean, if true invert the mask
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
@ -351,6 +363,7 @@ class Generate:
self.seed = seed
self.log_tokenization = log_tokenization
self.step_callback = step_callback
self.karras_max = karras_max
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache
@ -395,6 +408,11 @@ class Generate:
self.sampler_name = sampler_name
self._set_sampler()
# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for
if karras_max is not None and isinstance(self.sampler,KSampler):
self.sampler.adjust_settings(karras_max=karras_max)
tic = time.time()
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
@ -404,7 +422,7 @@ class Generate:
mask_image = None
try:
-uc, c = get_uc_and_c(
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=skip_normalize,
log_tokens =self.log_tokenization
@ -417,19 +435,12 @@ class Generate:
height,
fit=fit,
text_mask=text_mask,
invert_mask=invert_mask,
force_outpaint=force_outpaint,
)
# TODO: Hacky selection of operation to perform. Needs to be refactored.
-if (init_image is not None) and (mask_image is not None):
-generator = self._make_inpaint()
-elif (embiggen != None or embiggen_tiles != None):
-generator = self._make_embiggen()
-elif init_image is not None:
-generator = self._make_img2img()
-elif hires_fix:
-generator = self._make_txt2img2img()
-else:
-generator = self._make_txt2img()
generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)
generator.set_variation(
self.seed, variation_amount, with_variations
@ -448,7 +459,7 @@ class Generate:
sampler=self.sampler,
steps=steps,
cfg_scale=cfg_scale,
-conditioning=(uc, c),
conditioning=(uc, c, extra_conditioning_info),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
@ -464,7 +475,13 @@ class Generate:
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius,
-safety_checker=checker
safety_checker=checker,
seam_size = seam_size,
seam_blur = seam_blur,
seam_strength = seam_strength,
seam_steps = seam_steps,
tile_size = tile_size,
force_outpaint = force_outpaint
)
if init_color:
@ -481,14 +498,14 @@ class Generate:
save_original = save_original,
image_callback = image_callback)
-except RuntimeError as e:
-print(traceback.format_exc(), file=sys.stderr)
-print('>> Could not generate image.')
except KeyboardInterrupt:
if catch_interrupts:
print('**Interrupted** Partial results will be returned.')
else:
raise KeyboardInterrupt
except RuntimeError as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
toc = time.time()
print('>> Usage stats:')
@ -545,7 +562,7 @@ class Generate:
# try to reuse the same filename prefix as the original file.
# we take everything up to the first period
prefix = None
-m = re.match('^([^.]+)\.',os.path.basename(image_path))
m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
if m:
prefix = m.groups()[0]
@ -553,7 +570,8 @@ class Generate:
image = Image.open(image_path)
# used by multiple postfixers
-uc, c = get_uc_and_c(
# todo: cross-attention control
uc, c, _ = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
@ -598,10 +616,9 @@ class Generate:
elif tool == 'embiggen':
# fetch the metadata from the image
-generator = self._make_embiggen()
generator = self.select_generator(embiggen=True)
opt.strength = 0.40
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
-# embiggen takes a image path (sigh)
generator.generate(
prompt,
sampler = self.sampler,
@ -635,6 +652,32 @@ class Generate:
print(f'* postprocessing tool {tool} is not yet supported')
return None
def select_generator(
self,
init_image:Image.Image=None,
mask_image:Image.Image=None,
embiggen:bool=False,
hires_fix:bool=False,
force_outpaint:bool=False,
):
inpainting_model_in_use = self.sampler.uses_inpainting_model()
if hires_fix:
return self._make_txt2img2img()
if embiggen is not None:
return self._make_embiggen()
if inpainting_model_in_use:
return self._make_omnibus()
if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
return self._make_inpaint()
if init_image is not None:
return self._make_img2img()
return self._make_txt2img()
def _make_images(
self,
@ -644,6 +687,8 @@ class Generate:
height,
fit=False,
text_mask=None,
invert_mask=False,
force_outpaint=False,
):
init_image = None
init_mask = None
@ -657,7 +702,7 @@ class Generate:
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image):
-self._transparency_check_and_warning(image, mask)
self._transparency_check_and_warning(image, mask, force_outpaint)
init_mask = self._create_init_mask(image, width, height, fit=fit)
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
@ -673,8 +718,12 @@ class Generate:
elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
if invert_mask:
init_mask = ImageOps.invert(init_mask)
return init_image,init_mask
# lots o' repeated code here! Turn into a make_func()
def _make_base(self):
if not self.generators.get('base'):
from ldm.invoke.generator import Generator
@ -685,6 +734,7 @@ class Generate:
if not self.generators.get('img2img'):
from ldm.invoke.generator.img2img import Img2Img
self.generators['img2img'] = Img2Img(self.model, self.precision)
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
return self.generators['img2img']
def _make_embiggen(self):
@ -713,6 +763,15 @@ class Generate:
self.generators['inpaint'] = Inpaint(self.model, self.precision)
return self.generators['inpaint']
# "omnibus" supports the runwayML custom inpainting model, which does
# txt2img, img2img and inpainting using slight variations on the same code
def _make_omnibus(self):
if not self.generators.get('omnibus'):
from ldm.invoke.generator.omnibus import Omnibus
self.generators['omnibus'] = Omnibus(self.model, self.precision)
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
return self.generators['omnibus']
def load_model(self): def load_model(self):
''' '''
preload model identified in self.model_name preload model identified in self.model_name
@ -839,6 +898,8 @@ class Generate:
def sample_to_image(self, samples): def sample_to_image(self, samples):
return self._make_base().sample_to_image(samples) return self._make_base().sample_to_image(samples)
# very repetitive code - can this be simplified? The KSampler names are
# consistent, at least
def _set_sampler(self): def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}' msg = f'>> Setting Sampler to {self.sampler_name}'
if self.sampler_name == 'plms': if self.sampler_name == 'plms':
@ -846,15 +907,11 @@ class Generate:
elif self.sampler_name == 'ddim': elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device) self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a': elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler( self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
self.model, 'dpm_2_ancestral', device=self.device
)
elif self.sampler_name == 'k_dpm_2': elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device) self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
elif self.sampler_name == 'k_euler_a': elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler( self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
self.model, 'euler_ancestral', device=self.device
)
elif self.sampler_name == 'k_euler': elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device) self.sampler = KSampler(self.model, 'euler', device=self.device)
elif self.sampler_name == 'k_heun': elif self.sampler_name == 'k_heun':
@ -888,7 +945,8 @@ class Generate:
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
return image return image
def _create_init_image(self, image, width, height, fit=True): def _create_init_image(self, image: Image.Image, width, height, fit=True):
if image.mode != 'RGBA':
image = image.convert('RGB') image = image.convert('RGB')
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
return image return image
@ -954,11 +1012,11 @@ class Generate:
colored += 1 colored += 1
return colored == 0 return colored == 0
def _transparency_check_and_warning(self,image, mask): def _transparency_check_and_warning(self,image, mask, force_outpaint=False):
if not mask: if not mask:
print( print(
'>> Initial image has transparent areas. Will inpaint in these regions.') '>> Initial image has transparent areas. Will inpaint in these regions.')
if self._check_for_erasure(image): if (not force_outpaint) and self._check_for_erasure(image):
print( print(
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n', '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
'>> Inpainting will be suboptimal. Please preserve the colors when making\n', '>> Inpainting will be suboptimal. Please preserve the colors when making\n',

@ -83,16 +83,16 @@ with metadata_from_png():
import argparse import argparse
from argparse import Namespace, RawTextHelpFormatter from argparse import Namespace, RawTextHelpFormatter
import pydoc import pydoc
import shlex
import json import json
import hashlib import hashlib
import os import os
import re import re
import shlex
import copy import copy
import base64 import base64
import functools import functools
import ldm.invoke.pngwriter import ldm.invoke.pngwriter
from ldm.invoke.conditioning import split_weighted_subprompts from ldm.invoke.prompt_parser import split_weighted_subprompts
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
'ddim', 'ddim',
@ -169,28 +169,31 @@ class Args(object):
def parse_cmd(self,cmd_string): def parse_cmd(self,cmd_string):
'''Parse a invoke>-style command string ''' '''Parse a invoke>-style command string '''
command = cmd_string.replace("'", "\\'") # handle the case in which the first token is a switch
try: if cmd_string.startswith('-'):
elements = shlex.split(command) prompt = ''
elements = [x.replace("\\'","'") for x in elements] switches = cmd_string
except ValueError: # handle the case in which the prompt is enclosed by quotes
import sys, traceback elif cmd_string.startswith('"'):
print(traceback.format_exc(), file=sys.stderr) a = shlex.split(cmd_string,comments=True)
return prompt = a[0]
switches = [''] switches = shlex.join(a[1:])
switches_started = False
for element in elements:
if element[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(element)
else: else:
switches[0] += element # no initial quote, so get everything up to the first thing
switches[0] += ' ' # that looks like a switch
switches[0] = switches[0][: len(switches[0]) - 1] if cmd_string.startswith('-'):
prompt = ''
switches = cmd_string
else:
match = re.match('^(.+?)\s(--?[a-zA-Z].+)',cmd_string)
if match:
prompt,switches = match.groups()
else:
prompt = cmd_string
switches = ''
try: try:
self._cmd_switches = self._cmd_parser.parse_args(switches) self._cmd_switches = self._cmd_parser.parse_args(shlex.split(switches,comments=True))
setattr(self._cmd_switches,'prompt',prompt)
return self._cmd_switches return self._cmd_switches
except: except:
return None return None
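For illustration, a standalone sketch of the three branches above (leading switch, quoted prompt, unquoted prompt); the helper name is hypothetical and only mirrors the splitting logic of parse_cmd():

import re, shlex

def _split_prompt_and_switches(cmd_string: str):
    # hypothetical helper mirroring parse_cmd()'s prompt/switch splitting
    if cmd_string.startswith('-'):                  # switches only, no prompt
        return '', cmd_string
    if cmd_string.startswith('"'):                  # quoted prompt
        parts = shlex.split(cmd_string, comments=True)
        return parts[0], shlex.join(parts[1:])
    match = re.match(r'^(.+?)\s(--?[a-zA-Z].+)', cmd_string)   # unquoted prompt
    return match.groups() if match else (cmd_string, '')

print(_split_prompt_and_switches('"a castle on a hill" -s 50 -W 640'))
# -> ('a castle on a hill', '-s 50 -W 640')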
@ -211,13 +214,16 @@ class Args(object):
a = vars(self) a = vars(self)
a.update(kwargs) a.update(kwargs)
switches = list() switches = list()
switches.append(f'"{a["prompt"]}"') prompt = a['prompt']
prompt = prompt.replace('"','\\"')  # str.replace returns a new string; keep the escaped copy
switches.append(prompt)
switches.append(f'-s {a["steps"]}') switches.append(f'-s {a["steps"]}')
switches.append(f'-S {a["seed"]}') switches.append(f'-S {a["seed"]}')
switches.append(f'-W {a["width"]}') switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}') switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}') switches.append(f'-C {a["cfg_scale"]}')
switches.append(f'--fnformat {a["fnformat"]}') if a['karras_max'] is not None:
switches.append(f'--karras_max {a["karras_max"]}')
if a['perlin'] > 0: if a['perlin'] > 0:
switches.append(f'--perlin {a["perlin"]}') switches.append(f'--perlin {a["perlin"]}')
if a['threshold'] > 0: if a['threshold'] > 0:
@ -243,6 +249,8 @@ class Args(object):
switches.append(f'-f {a["strength"]}') switches.append(f'-f {a["strength"]}')
if a['inpaint_replace']: if a['inpaint_replace']:
switches.append(f'--inpaint_replace') switches.append(f'--inpaint_replace')
if a['text_mask']:
switches.append(f'-tm {" ".join([str(u) for u in a["text_mask"]])}')
else: else:
switches.append(f'-A {a["sampler_name"]}') switches.append(f'-A {a["sampler_name"]}')
@ -567,10 +575,17 @@ class Args(object):
) )
render_group = parser.add_argument_group('General rendering') render_group = parser.add_argument_group('General rendering')
img2img_group = parser.add_argument_group('Image-to-image and inpainting') img2img_group = parser.add_argument_group('Image-to-image and inpainting')
inpainting_group = parser.add_argument_group('Inpainting')
outpainting_group = parser.add_argument_group('Outpainting and outcropping')
variation_group = parser.add_argument_group('Creating and combining variations') variation_group = parser.add_argument_group('Creating and combining variations')
postprocessing_group = parser.add_argument_group('Post-processing') postprocessing_group = parser.add_argument_group('Post-processing')
special_effects_group = parser.add_argument_group('Special effects') special_effects_group = parser.add_argument_group('Special effects')
render_group.add_argument('prompt') deprecated_group = parser.add_argument_group('Deprecated options')
render_group.add_argument(
'--prompt',
default='',
help='prompt string',
)
render_group.add_argument( render_group.add_argument(
'-s', '-s',
'--steps', '--steps',
@ -688,7 +703,13 @@ class Args(object):
default=6, default=6,
choices=range(0,10), choices=range(0,10),
dest='png_compression', dest='png_compression',
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.' help='level of PNG compression, from 0 (none) to 9 (maximum). [6]'
)
render_group.add_argument(
'--karras_max',
type=int,
default=None,
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
) )
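Two illustrative invocations (prompt and values are made up):

invoke> "a misty forest" -s 14 --karras_max 0      # use the LatentDiffusion schedule even at a low step count
invoke> "a misty forest" -s 50 --karras_max 1000   # keep the Karras schedule at any step count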
img2img_group.add_argument( img2img_group.add_argument(
'-I', '-I',
@ -696,12 +717,6 @@ class Args(object):
type=str, type=str,
help='Path to input image for img2img mode (supersedes width and height)', help='Path to input image for img2img mode (supersedes width and height)',
) )
img2img_group.add_argument(
'-M',
'--init_mask',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument( img2img_group.add_argument(
'-tm', '-tm',
'--text_mask', '--text_mask',
@ -729,29 +744,68 @@ class Args(object):
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely', help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
default=0.75, default=0.75,
) )
img2img_group.add_argument( inpainting_group.add_argument(
'-D', '-M',
'--out_direction', '--init_mask',
nargs='+',
type=str, type=str,
metavar=('direction', 'pixels'), help='Path to input mask for inpainting mode (supersedes width and height)',
help='Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size'
) )
img2img_group.add_argument( inpainting_group.add_argument(
'-c', '--invert_mask',
'--outcrop', action='store_true',
nargs='+', help='Invert the mask',
type=str,
metavar=('direction','pixels'),
help='Outcrop the image with one or more direction/pixel pairs: -c top 64 bottom 128 left 64 right 64',
) )
img2img_group.add_argument( inpainting_group.add_argument(
'-r', '-r',
'--inpaint_replace', '--inpaint_replace',
type=float, type=float,
default=0.0, default=0.0,
help='when inpainting, adjust how aggressively to replace the part of the picture under the mask, from 0.0 (a gentle merge) to 1.0 (replace entirely)', help='when inpainting, adjust how aggressively to replace the part of the picture under the mask, from 0.0 (a gentle merge) to 1.0 (replace entirely)',
) )
outpainting_group.add_argument(
'-c',
'--outcrop',
nargs='+',
type=str,
metavar=('direction','pixels'),
help='Outcrop the image with one or more direction/pixel pairs: e.g. -c top 64 bottom 128 left 64 right 64',
)
outpainting_group.add_argument(
'--force_outpaint',
action='store_true',
default=False,
help='Force outpainting if you have no inpainting mask to pass',
)
outpainting_group.add_argument(
'--seam_size',
type=int,
default=0,
help='When outpainting, size of the mask around the seam between original and outpainted image',
)
outpainting_group.add_argument(
'--seam_blur',
type=int,
default=0,
help='When outpainting, the amount to blur the seam inwards',
)
outpainting_group.add_argument(
'--seam_strength',
type=float,
default=0.7,
help='When outpainting, the img2img strength to use when filling the seam. Values around 0.7 work well',
)
outpainting_group.add_argument(
'--seam_steps',
type=int,
default=10,
help='When outpainting, the number of steps to use to fill the seam. Low values (~10) work well',
)
outpainting_group.add_argument(
'--tile_size',
type=int,
default=32,
help='When outpainting, the tile size to use for filling outpaint areas',
)
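For illustration, the seam and tile switches above might be combined like this when outpainting an image with transparent borders (values are made up, and the exact invocation form is an assumption rather than a documented workflow):

invoke> "snow-covered mountains" -I edges_erased.png --force_outpaint --seam_size 96 --seam_blur 16 --seam_strength 0.7 --seam_steps 10 --tile_size 32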
postprocessing_group.add_argument( postprocessing_group.add_argument(
'-ft', '-ft',
'--facetool', '--facetool',
@ -835,7 +889,14 @@ class Args(object):
dest='use_mps_noise', dest='use_mps_noise',
help='Simulate noise on M1 systems to get the same results' help='Simulate noise on M1 systems to get the same results'
) )
deprecated_group.add_argument(
'-D',
'--out_direction',
nargs='+',
type=str,
metavar=('direction', 'pixels'),
help='Older outcropping system. Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size'
)
return parser return parser
def format_metadata(**kwargs): def format_metadata(**kwargs):
@ -871,7 +932,7 @@ def metadata_dumps(opt,
# remove any image keys not mentioned in RFC #266 # remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps', rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','threshold','perlin','fnformat', 'step_number','width','height','extra','strength', 'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
'init_img','init_mask','facetool','facetool_strength','upscale'] 'init_img','init_mask','facetool','facetool_strength','upscale']
rfc_dict ={} rfc_dict ={}
@ -922,6 +983,23 @@ def metadata_dumps(opt,
return metadata return metadata
@functools.lru_cache(maxsize=50)
def args_from_png(png_file_path) -> list[Args]:
'''
Given the path to a PNG file created by invoke.py,
retrieves a list of Args objects containing the image
data.
'''
try:
meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path)
except AttributeError:
return [legacy_metadata_load({},png_file_path)]
try:
return metadata_loads(meta)
except:
return [legacy_metadata_load(meta,png_file_path)]
@functools.lru_cache(maxsize=50) @functools.lru_cache(maxsize=50)
def metadata_from_png(png_file_path) -> Args: def metadata_from_png(png_file_path) -> Args:
''' '''
@ -929,11 +1007,8 @@ def metadata_from_png(png_file_path) -> Args:
an Args object containing the image metadata. Note that this an Args object containing the image metadata. Note that this
returns a single Args object, not multiple. returns a single Args object, not multiple.
''' '''
meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path) args_list = args_from_png(png_file_path)
if 'sd-metadata' in meta and len(meta['sd-metadata'])>0 : return args_list[0]
return metadata_loads(meta)[0]
else:
return legacy_metadata_load(meta,png_file_path)
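A minimal usage sketch for the two loaders above (the PNG path is hypothetical, and Args is assumed to expose parsed values as attributes):

from ldm.invoke.args import args_from_png, metadata_from_png

png = 'outputs/img-samples/000001.png'    # hypothetical output file
opt = metadata_from_png(png)              # single Args object for the first image
all_args = args_from_png(png)             # list of Args, one per image recorded in the PNG
print(opt.prompt, opt.seed)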
def dream_cmd_from_png(png_file_path): def dream_cmd_from_png(png_file_path):
opt = metadata_from_png(png_file_path) opt = metadata_from_png(png_file_path)
@ -948,7 +1023,7 @@ def metadata_loads(metadata) -> list:
''' '''
results = [] results = []
try: try:
if 'grid' in metadata['sd-metadata']: if 'images' in metadata['sd-metadata']:
images = metadata['sd-metadata']['images'] images = metadata['sd-metadata']['images']
else: else:
images = [metadata['sd-metadata']['image']] images = [metadata['sd-metadata']['image']]

@ -4,107 +4,191 @@ weighted subprompts.
Useful function exports: Useful function exports:
get_uc_and_c() get the conditioned and unconditioned latent get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
split_weighted_subprompts() split subprompts, normalize and weight them split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated log_tokenization() print out colour-coded tokens and warn if truncated
''' '''
import re import re
from difflib import SequenceMatcher
from typing import Union
import torch import torch
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
# Extract Unconditioned Words From Prompt # Extract Unconditioned Words From Prompt
unconditioned_words = '' unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]' unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt) unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0: if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals) unconditioned_words = ' '.join(unconditionals)
# Remove Unconditioned Words From Prompt # Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex) unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(' ', prompt) clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
prompt = re.sub(' +', ' ', clean_prompt) prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
else:
prompt_string_cleaned = prompt_string_uncleaned
uc = model.get_learned_conditioning([unconditioned_words]) pp = PromptParser()
# get weighted sub-prompts parsed_prompt: Union[FlattenedPrompt, Blend] = None
weighted_subprompts = split_weighted_subprompts( legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
prompt, skip_normalize if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
# we don't support conjunctions for now
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
print(f">> Parsed prompt to {parsed_prompt}")
conditioning = None
cac_args:CrossAttentionControl.Arguments = None
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for flattened_prompt in blend.prompts:
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = type(flattened_prompt) is not Blend \
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_opcodes = []
edit_options = []
for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
to_replace_token_count = get_tokens_length(model, fragment.original)
replacement_token_count = get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
#elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
count = get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
conditioning = original_embeddings
edited_conditioning = edited_embeddings
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
edit_options = edit_options
)
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=cac_args
)
) )
if len(weighted_subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it
for subprompt, weight in weighted_subprompts:
log_tokenization(subprompt, model, log_tokens, weight)
c = torch.add(
c,
model.get_learned_conditioning([subprompt]),
alpha=weight,
)
else: # just standard 1 prompt
log_tokenization(prompt, model, log_tokens, 1)
c = model.get_learned_conditioning([prompt])
uc = model.get_learned_conditioning([unconditioned_words])
return (uc, c)
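For illustration, a sketch of how the triple returned by get_uc_and_c_and_ec() is consumed downstream; model, sampler, steps, cfg_scale and shape are assumed to be in scope, mirroring the txt2img generator later in this diff:

uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
    'a cat sitting on a car [blurry, low quality]', model, log_tokens=True
)
samples, _ = sampler.sample(
    S                            = steps,
    shape                        = shape,
    conditioning                 = c,
    unconditional_conditioning   = uc,
    unconditional_guidance_scale = cfg_scale,
    extra_conditioning_info      = extra_conditioning_info,
    verbose                      = False,
)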
def split_weighted_subprompts(text, skip_normalize=False)->list: def build_token_edit_opcodes(original_tokens, edited_tokens):
""" original_tokens = original_tokens.cpu().numpy()[0]
grabs all text up to the first occurrence of ':' edited_tokens = edited_tokens.cpu().numpy()[0]
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0 return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
repeats until no text remaining
""" def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
prompt_parser = re.compile(""" if type(flattened_prompt) is not FlattenedPrompt:
(?P<prompt> # capture group for 'prompt' raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' fragments = [x.text for x in flattened_prompt.children]
) # end 'prompt' weights = [x.weight for x in flattened_prompt.children]
(?: # non-capture group embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
:+ # match one or more ':' characters if not flattened_prompt.is_empty and log_tokens:
(?P<weight> # capture group for 'weight' start_token = model.cond_stage_model.tokenizer.bos_token_id
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number end_token = model.cond_stage_model.tokenizer.eos_token_id
)? # end weight capture group, make optional tokens_list = tokens[0].tolist()
\s* # strip spaces after weight if tokens_list[0] == start_token:
| # OR tokens_list[0] = '<start>'
$ # else, if no ':' then match end of line try:
) # end non-capture group first_end_token_index = tokens_list.index(end_token)
""", re.VERBOSE) tokens_list[first_end_token_index] = '<end>'
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( tokens_list = tokens_list[:first_end_token_index+1]
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] except ValueError:
if skip_normalize: pass
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts)) print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
if weight_sum == 0:
print( return embeddings, tokens
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1) def get_tokens_length(model, fragments: list[Fragment]):
return [(x[0], equal_weight) for x in parsed_prompts] fragment_texts = [x.text for x in fragments]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts] tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
def flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
dimensions as well, as used by 'hybrid'
'''
assert isinstance(uncond, dict)
assert isinstance(cond, dict)
cond_flattened = dict()
for k in cond:
if isinstance(cond[k], list):
cond_flattened[k] = [
torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k]))
]
else:
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return uncond, cond_flattened
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
if not log:
return
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)

@ -6,6 +6,7 @@ import torch
import numpy as np import numpy as np
import random import random
import os import os
import traceback
from tqdm import tqdm, trange from tqdm import tqdm, trange
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from einops import rearrange, repeat from einops import rearrange, repeat
@ -29,6 +30,7 @@ class Generator():
self.variation_amount = 0 self.variation_amount = 0
self.with_variations = [] self.with_variations = []
self.use_mps_noise = False self.use_mps_noise = False
self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs): def get_make_image(self,prompt,**kwargs):
@ -43,7 +45,7 @@ class Generator():
self.variation_amount = variation_amount self.variation_amount = variation_amount
self.with_variations = with_variations self.with_variations = with_variations
def generate(self,prompt,init_image,width,height,iterations=1,seed=None, def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None, safety_checker:dict=None,
**kwargs): **kwargs):
@ -51,6 +53,7 @@ class Generator():
self.safety_checker = safety_checker self.safety_checker = safety_checker
make_image = self.get_make_image( make_image = self.get_make_image(
prompt, prompt,
sampler = sampler,
init_image = init_image, init_image = init_image,
width = width, width = width,
height = height, height = height,
@ -59,12 +62,14 @@ class Generator():
perlin = perlin, perlin = perlin,
**kwargs **kwargs
) )
results = [] results = []
seed = seed if seed is not None else self.new_seed() seed = seed if seed is not None else self.new_seed()
first_seed = seed first_seed = seed
seed, initial_noise = self.generate_initial_noise(seed, width, height) seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(self.model.device.type), self.model.ema_scope():
# There used to be an additional self.model.ema_scope() here, but it breaks
# the inpaint-1.5 model. Not sure what it did.... ?
with scope(self.model.device.type):
for n in trange(iterations, desc='Generating'): for n in trange(iterations, desc='Generating'):
x_T = None x_T = None
if self.variation_amount > 0: if self.variation_amount > 0:
@ -79,7 +84,8 @@ class Generator():
try: try:
x_T = self.get_noise(width,height) x_T = self.get_noise(width,height)
except: except:
pass print('** An error occurred while getting initial noise **')
print(traceback.format_exc())
image = make_image(x_T) image = make_image(x_T)
@ -95,10 +101,10 @@ class Generator():
return results return results
def sample_to_image(self,samples): def sample_to_image(self,samples)->Image.Image:
""" """
Returns a function returning an image derived from the prompt and the initial image Given samples returned from a sampler, converts
Return value depends on the seed at the time you call it it into a PIL Image
""" """
x_samples = self.model.decode_first_stage(samples) x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

@ -21,6 +21,7 @@ class Embiggen(Generator):
def generate(self,prompt,iterations=1,seed=None, def generate(self,prompt,iterations=1,seed=None,
image_callback=None, step_callback=None, image_callback=None, step_callback=None,
**kwargs): **kwargs):
scope = choose_autocast(self.precision) scope = choose_autocast(self.precision)
make_image = self.get_make_image( make_image = self.get_make_image(
prompt, prompt,
@ -63,6 +64,8 @@ class Embiggen(Generator):
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
Return value depends on the seed at the time you call it Return value depends on the seed at the time you call it
""" """
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
# Construct embiggen arg array, and sanity check arguments # Construct embiggen arg array, and sanity check arguments
if embiggen == None: # embiggen can also be called with just embiggen_tiles if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling embiggen = [1.0] # If not specified, assume no scaling

@ -10,6 +10,7 @@ from PIL import Image
from ldm.invoke.devices import choose_autocast from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Img2Img(Generator): class Img2Img(Generator):
def __init__(self, model, precision): def __init__(self, model, precision):
@ -29,7 +30,7 @@ class Img2Img(Generator):
) )
if isinstance(init_image, PIL.Image.Image): if isinstance(init_image, PIL.Image.Image):
init_image = self._image_to_tensor(init_image) init_image = self._image_to_tensor(init_image.convert('RGB'))
scope = choose_autocast(self.precision) scope = choose_autocast(self.precision)
with scope(self.model.device.type): with scope(self.model.device.type):
@ -38,7 +39,7 @@ class Img2Img(Generator):
) # move to latent space ) # move to latent space
t_enc = int(strength * steps) t_enc = int(strength * steps)
uc, c = conditioning uc, c, extra_conditioning_info = conditioning
def make_image(x_T): def make_image(x_T):
# encode (scaled latent) # encode (scaled latent)
@ -56,6 +57,8 @@ class Img2Img(Generator):
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
init_latent = self.init_latent, # changes how noising is performed in ksampler init_latent = self.init_latent, # changes how noising is performed in ksampler
extra_conditioning_info = extra_conditioning_info,
all_timesteps_count = steps
) )
return self.sample_to_image(samples) return self.sample_to_image(samples)
@ -77,6 +80,9 @@ class Img2Img(Generator):
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor: def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)
if normalize: if normalize:

@ -2,12 +2,13 @@
ldm.invoke.generator.inpaint descends from ldm.invoke.generator ldm.invoke.generator.inpaint descends from ldm.invoke.generator
''' '''
import math
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
import numpy as np import numpy as np
import cv2 as cv import cv2 as cv
import PIL import PIL
from PIL import Image, ImageFilter from PIL import Image, ImageFilter, ImageOps
from skimage.exposure.histogram_matching import match_histograms from skimage.exposure.histogram_matching import match_histograms
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.invoke.devices import choose_autocast from ldm.invoke.devices import choose_autocast
@ -24,11 +25,128 @@ class Inpaint(Img2Img):
self.mask_blur_radius = 0 self.mask_blur_radius = 0
super().__init__(model, precision) super().__init__(model, precision)
# Outpaint support code
def get_tile_images(self, image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None
return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False
)
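A quick illustrative shape check for the strided tiling above (array contents are arbitrary; inpaint is assumed to be an Inpaint instance):

import numpy as np

img = np.zeros((64, 96, 4), dtype=np.uint8)                # H=64, W=96, RGBA
tiles = inpaint.get_tile_images(img, width=16, height=16)
print(tiles.shape)                                         # (4, 6, 16, 16, 4): rows, cols, tile_h, tile_w, depth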
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
a = np.asarray(im, dtype=np.uint8)
tile_size = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = self.get_tile_images(a,*tile_size).copy()
# Get the mask as tiles
tiles_mask = tiles[:,:,:,:,3]
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = (tiles_mask > 0)
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]
if len(filtered_tiles) == 0:
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed = seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1,2)
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
si = Image.fromarray(st, mode='RGBA')
return si
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
# Detect hard edges
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
# Combine
npmask = npgradient + npedge
# Expand
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
new_mask = Image.fromarray(npmask)
if edge_blur > 0:
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
return ImageOps.invert(new_mask)
def seam_paint(self,
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,strength,
noise
) -> Image.Image:
hard_mask = self.pil_image.split()[-1].copy()
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
make_image = self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_image = im.copy().convert('RGBA'),
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
strength = strength,
mask_blur_radius = 0,
seam_size = 0
)
result = make_image(noise)
return result
@torch.no_grad() @torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength, conditioning,init_image,mask_image,strength,
mask_blur_radius: int = 8, mask_blur_radius: int = 8,
step_callback=None,inpaint_replace=False, **kwargs): # Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False, **kwargs):
""" """
Returns a function returning an image derived from the prompt and Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at the initial image + mask. Return value depends on the seed at
@ -37,7 +155,17 @@ class Inpaint(Img2Img):
if isinstance(init_image, PIL.Image.Image): if isinstance(init_image, PIL.Image.Image):
self.pil_image = init_image self.pil_image = init_image
init_image = self._image_to_tensor(init_image)
# Fill missing areas of original image
init_filled = self.tile_fill_missing(
self.pil_image.copy(),
seed = self.seed,
tile_size = tile_size
)
init_filled.paste(init_image, (0,0), init_image.split()[-1])
# Create init tensor
init_image = self._image_to_tensor(init_filled.convert('RGB'))
if isinstance(mask_image, PIL.Image.Image): if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image self.pil_mask = mask_image
@ -73,7 +201,8 @@ class Inpaint(Img2Img):
) # move to latent space ) # move to latent space
t_enc = int(strength * steps) t_enc = int(strength * steps)
uc, c = conditioning # todo: support cross-attention control
uc, c, _ = conditioning
print(f">> target t_enc is {t_enc} steps") print(f">> target t_enc is {t_enc} steps")
@ -105,38 +234,56 @@ class Inpaint(Img2Img):
mask = mask_image, mask = mask_image,
init_latent = self.init_latent init_latent = self.init_latent
) )
return self.sample_to_image(samples)
result = self.sample_to_image(samples)
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:
result = self.seam_paint(
result,
seam_size,
seam_blur,
prompt,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
conditioning,
seam_strength,
x_T)
return result
return make_image return make_image
def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
if self.pil_image is None or self.pil_mask is None:
return gen_result
pil_mask = self.pil_mask
pil_image = self.pil_image
mask_blur_radius = self.mask_blur_radius
def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
# Get the original alpha channel of the mask if there is one. # Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB') # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L') pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching. # Build an image with only visible pixels from source to use as reference for color-matching.
# Note that this doesn't use the mask, which would exclude some source image pixels from the init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
# histogram and cause slight color changes. init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
# Get numpy version # Get numpy version of result
np_gen_result = np.asarray(gen_result, dtype=np.uint8) np_image = np.asarray(image, dtype=np.uint8)
# Mask and calculate mean and standard deviation
mask_pixels = init_a_pixels * init_mask_pixels > 0
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]
init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)
# Color correct # Color correct
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB') matched_result = Image.fromarray(np_matched_result, mode='RGB')
# Blur the mask out (into init image) by specified amount # Blur the mask out (into init image) by specified amount
@ -149,6 +296,16 @@ class Inpaint(Img2Img):
blurred_init_mask = pil_init_mask blurred_init_mask = pil_init_mask
# Paste original on color-corrected generation (using blurred mask) # Paste original on color-corrected generation (using blurred mask)
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
return matched_result return matched_result
def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
if self.pil_image is None or self.pil_mask is None:
return gen_result
corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
return corrected_result

@ -0,0 +1,153 @@
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
import torch
import numpy as np
from einops import repeat
from PIL import Image, ImageOps
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img
class Omnibus(Img2Img,Txt2Img):
def __init__(self, model, precision):
super().__init__(model, precision)
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
width,
height,
init_image = None,
mask_image = None,
strength = None,
step_callback=None,
threshold=0.0,
perlin=0.0,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
num_samples = 1
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, Image.Image):
if init_image.mode != 'RGB':
init_image = init_image.convert('RGB')
init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, Image.Image):
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
t_enc = steps
if init_image is not None and mask_image is not None: # inpainting
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
elif init_image is not None: # img2img
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
# create a completely black mask (1s)
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
# and the masked image is just a copy of the original
masked_image = init_image
else: # txt2img
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
masked_image = init_image
self.init_latent = init_image
height = init_image.shape[2]
width = init_image.shape[3]
model = self.model
def make_image(x_T):
with torch.no_grad():
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
batch = self.make_batch_sd(
init_image,
mask_image,
masked_image,
prompt=prompt,
device=model.device,
num_samples=num_samples,
)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()
for ck in model.concat_keys:
cc = batch[ck].float()
if ck != model.masked_image_key:
bchw = [num_samples, 4, height//8, width//8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
# cond
cond={"c_concat": [c_cat], "c_crossattn": [c]}
# uncond cond
uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, height//8, width//8]
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
x_T = x_T,
conditioning = cond,
shape = shape,
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc_full,
eta = 1.0,
img_callback = step_callback,
threshold = threshold,
)
if self.free_gpu_mem:
self.model.model.to("cpu")
return self.sample_to_image(samples)
return make_image
def make_batch_sd(
self,
image,
mask,
masked_image,
prompt,
device,
num_samples=1):
batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [prompt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch
def get_noise(self, width:int, height:int):
if self.init_latent is not None:
height = self.init_latent.shape[2]
width = self.init_latent.shape[3]
return Txt2Img.get_noise(self,width,height)
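For reference, an illustrative summary of the channel bookkeeping behind the "9-channel" description above (shapes assume a 512x512 image and the usual 8x latent downsampling):

#   x (noised latent being denoised)  : (1, 4, 64, 64)
#   c_concat, mask                    : (1, 1, 64, 64)   interpolated to latent size above
#   c_concat, masked-image latent     : (1, 4, 64, 64)   VAE-encoded masked_image above
# Concatenated inside the model, the UNet sees 4 + 1 + 4 = 9 input channels.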

@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
import torch import torch
import numpy as np import numpy as np
from ldm.invoke.generator.base import Generator from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Txt2Img(Generator): class Txt2Img(Generator):
def __init__(self, model, precision): def __init__(self, model, precision):
@ -19,7 +21,7 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height' kwargs are 'width' and 'height'
""" """
self.perlin = perlin self.perlin = perlin
uc, c = conditioning uc, c, extra_conditioning_info = conditioning
@torch.no_grad() @torch.no_grad()
def make_image(x_T): def make_image(x_T):
@ -43,6 +45,7 @@ class Txt2Img(Generator):
verbose = False, verbose = False,
unconditional_guidance_scale = cfg_scale, unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc, unconditional_conditioning = uc,
extra_conditioning_info = extra_conditioning_info,
eta = ddim_eta, eta = ddim_eta,
img_callback = step_callback, img_callback = step_callback,
threshold = threshold, threshold = threshold,

@ -7,7 +7,9 @@ import numpy as np
import math import math
from ldm.invoke.generator.base import Generator from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.generator.omnibus import Omnibus
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from PIL import Image
class Txt2Img2Img(Generator): class Txt2Img2Img(Generator):
def __init__(self, model, precision): def __init__(self, model, precision):
@ -22,7 +24,7 @@ class Txt2Img2Img(Generator):
Return value depends on the seed at the time you call it Return value depends on the seed at the time you call it
kwargs are 'width' and 'height' kwargs are 'width' and 'height'
""" """
uc, c = conditioning uc, c, extra_conditioning_info = conditioning
@torch.no_grad() @torch.no_grad()
def make_image(x_T): def make_image(x_T):
@ -59,7 +61,8 @@ class Txt2Img2Img(Generator):
unconditional_guidance_scale = cfg_scale, unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc, unconditional_conditioning = uc,
eta = ddim_eta, eta = ddim_eta,
img_callback = step_callback img_callback = step_callback,
extra_conditioning_info = extra_conditioning_info
) )
print( print(
@ -93,6 +96,8 @@ class Txt2Img2Img(Generator):
img_callback = step_callback, img_callback = step_callback,
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
extra_conditioning_info=extra_conditioning_info,
all_timesteps_count=steps
) )
if self.free_gpu_mem: if self.free_gpu_mem:
@ -100,8 +105,49 @@ class Txt2Img2Img(Generator):
return self.sample_to_image(samples) return self.sample_to_image(samples)
return make_image # in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a 512x512 PIL image, upscale it, and run the inpainting
# over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
def inpaint_make_image(x_T):
omnibus = Omnibus(self.model,self.precision)
result = omnibus.generate(
prompt,
sampler=sampler,
width=init_width,
height=init_height,
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
assert result is not None and len(result)>0,'** txt2img failed **'
image = result[0][0]
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
kwargs.pop('init_image', None)  # discard any caller-supplied init_image; the interpolated image is passed explicitly below
result = omnibus.generate(
prompt,
sampler=sampler,
init_image=interpolated_image,
width=width,
height=height,
seed=result[0][1],
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
return result[0][0]
if sampler.uses_inpainting_model():
return inpaint_make_image
else:
return make_image
# returns a tensor filled with random numbers from a normal distribution # returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True): def get_noise(self,width,height,scale = True):
@ -129,3 +175,4 @@ class Txt2Img2Img(Generator):
scaled_height // self.downsampling_factor, scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor], scaled_width // self.downsampling_factor],
device=device) device=device)

@ -13,6 +13,7 @@ import gc
import hashlib import hashlib
import psutil import psutil
import transformers import transformers
import traceback
import os import os
from sys import getrefcount from sys import getrefcount
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -73,6 +74,7 @@ class ModelCache(object):
self.models[model_name]['hash'] = hash self.models[model_name]['hash'] = hash
except Exception as e: except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}') print(f'** model {model_name} could not be loaded: {str(e)}')
print(traceback.format_exc())
print(f'** restoring {self.current_model}') print(f'** restoring {self.current_model}')
self.get_model(self.current_model) self.get_model(self.current_model)
return None return None

693
ldm/invoke/prompt_parser.py Normal file
@ -0,0 +1,693 @@
import string
from typing import Union, Optional
import re
import pyparsing as pp
class Prompt():
"""
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
that are to be blended together are stored individually as Prompt objects.
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
to produce a FlattenedPrompt.
"""
def __init__(self, parts: list):
for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
self.children = parts
def __repr__(self):
return f"Prompt:{self.children}"
def __eq__(self, other):
return type(other) is Prompt and other.children == self.children
class BaseFragment:
pass
class FlattenedPrompt():
"""
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
"""
def __init__(self, parts: list=[]):
self.children = []
for part in parts:
self.append(part)
def append(self, fragment: Union[list, BaseFragment, tuple]):
# verify type correctness
if type(fragment) is list:
for x in fragment:
self.append(x)
elif issubclass(type(fragment), BaseFragment):
self.children.append(fragment)
elif type(fragment) is tuple:
# upgrade tuples to Fragments
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
self.children.append(Fragment(fragment[0], fragment[1]))
else:
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
@property
def is_empty(self):
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
return type(other) is FlattenedPrompt and other.children == self.children
class Fragment(BaseFragment):
"""
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
"""
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
if '\\"' in text or '\\(' in text or '\\)' in text:
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
self.text = text
self.weight = float(weight)
def __repr__(self):
return "Fragment:'"+self.text+"'@"+str(self.weight)
def __eq__(self, other):
return type(other) is Fragment \
and other.text == self.text \
and other.weight == self.weight
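A small illustrative check of the escape handling in Fragment above:

f = Fragment('a \\(bracketed\\) phrase', 1.1)
assert f.text == 'a (bracketed) phrase'
assert f.weight == 1.1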
class Attention():
"""
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
Do not traverse directly; instead obtain a FlattenedPrompt by calling flatten() on a top-level Conjunction object.
"""
def __init__(self, weight: float, children: list):
self.weight = weight
self.children = children
#print(f"A: requested attention '{children}' to {weight}")
def __repr__(self):
return f"Attention:'{self.children}' @ {self.weight}"
def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.children == self.children
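For illustration, nested Attention objects built by hand; per the docstring the weights accumulate as the tree is traversed, so 'cat' here ends up with an effective weight of roughly 1.5 * 1.2 once flattened (flattening itself lives in the PromptParser, outside this excerpt):

inner = Attention(1.2, [Fragment('cat')])
outer = Attention(1.5, [Fragment('a'), inner])   # 'a' at 1.5, 'cat' at ~1.5 * 1.2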
class CrossAttentionControlledFragment(BaseFragment):
pass
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
"""
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
It represents an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
"edited" word sequence to which the attention maps produced by the "original" word sequence are applied. Intuitively,
the result should be an "edited" image that looks like the "original" image with concepts swapped.
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
or it may represent a larger portion of the token sequence:
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
edited=[Fragment('a smiling dog sitting on a car')])
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
FlattenedPrompt([
Fragment('a'),
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
Fragment('sitting on a car')
])
"""
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
self.original = original
self.edited = edited
default_options = {
's_start': 0.0,
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
't_start': 0.0,
't_end': 1.0
}
merged_options = default_options
if options is not None:
shape_freedom = options.pop('shape_freedom', None)
if shape_freedom is not None:
# high shape freedom = SD can do what it wants with the shape of the object
# high shape freedom => s_end = 0
# low shape freedom => s_end = 1
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
# and there is very little perceptible difference as s_end increases above 0.5
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
# -> cube root and subtract from 1.0
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
#print('converted shape_freedom argument to', merged_options)
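# worked example (illustrative): shape_freedom=0.5 gives s_end = 1.0 - 0.5 ** (1. / 3.) ~= 0.2063,
# matching the default s_end above; shape_freedom=1.0 gives s_end = 0.0 (maximum freedom for SD)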
merged_options.update(options)
self.options = merged_options
def __repr__(self):
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \
and other.edited == self.edited \
and other.options == self.options
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
def __init__(self, fragment: Fragment):
self.fragment = fragment
def __repr__(self):
return "CrossAttentionControlAppend:",self.fragment
def __eq__(self, other):
return type(other) is CrossAttentionControlAppend \
and other.fragment == self.fragment
class Conjunction():
"""
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
by weighted sum in latent space.
"""
def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt
#print("making conjunction with", parts)
self.prompts = [x if (type(x) is Prompt
or type(x) is Blend
or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
if len(self.weights) != len(self.prompts):
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
self.type = 'AND'
def __repr__(self):
return f"Conjunction:{self.prompts} | weights {self.weights}"
def __eq__(self, other):
return type(other) is Conjunction \
and other.prompts == self.prompts \
and other.weights == self.weights
class Blend():
"""
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
Prompts.
"""
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights)
if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for c in prompts:
if type(c) is not Prompt and type(c) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = weights
self.normalize_weights = normalize_weights
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
return other.__repr__() == self.__repr__()
class PromptParser():
class ParsingException(Exception):
pass
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
def parse_conjunction(self, prompt: str) -> Conjunction:
'''
:param prompt: The prompt string to parse
:return: a Conjunction representing the parsed results.
'''
#print(f"!!parsing '{prompt}'")
if len(prompt.strip()) == 0:
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
root = self.conjunction.parse_string(prompt)
#print(f"'{prompt}' parsed to root", root)
#fused = fuse_fragments(parts)
#print("fused to", fused)
return self.flatten(root[0])
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
def flatten(self, root: Conjunction) -> Conjunction:
"""
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
that can be readily tokenized without the need to walk a complex tree structure.
:param root: The Conjunction to flatten.
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
"""
#print("flattening", root)
def fuse_fragments(items):
# print("fusing fragments in ", items)
result = []
for x in items:
if type(x) is CrossAttentionControlSubstitute:
original_fused = fuse_fragments(x.original)
edited_fused = fuse_fragments(x.edited)
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
else:
last_weight = result[-1].weight \
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
else None
this_text = x.text
this_weight = x.weight
if last_weight is not None and last_weight == this_weight:
last_text = result[-1].text
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
else:
result.append(x)
return result
def flatten_internal(node, weight_scale, results, prefix):
#print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults:
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
#print(prefix, " ParseResults expanded, results is now", results)
elif type(node) is Attention:
# if node.weight < 1:
# todo: inject a blend when flattening attention with weight < 1
for index,c in enumerate(node.children):
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
elif type(node) is Blend:
flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
for prompt in node.prompts:
# prompt is a list
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
elif type(node) is Prompt:
#print(prefix + "about to flatten Prompt with children", node.children)
flattened_prompt = []
for child in node.children:
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results
flattened_parts = []
for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
#print("flattened to", flattened_parts)
weights = root.weights
return Conjunction(flattened_parts, weights)
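A rough usage sketch of the parser defined in this file (the module path is assumed to be ldm.invoke.prompt_parser; the exact fragment grouping shown in the trailing comment is illustrative rather than guaranteed):
from ldm.invoke.prompt_parser import PromptParser

parser = PromptParser()
conjunction = parser.parse_conjunction('a (white cat)1.2 sitting on a car')
for flattened_prompt in conjunction.prompts:
    # each FlattenedPrompt holds a linear list of Fragment / CrossAttentionControlSubstitute objects
    for fragment in flattened_prompt.children:
        print(fragment.text, fragment.weight)
# expected output (roughly):
#   a 1.0
#   white cat 1.2
#   sitting on a car 1.0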
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
quotes = pp.Literal('"').suppress()
comma = pp.Literal(",").suppress()
# accepts int or float notation, always maps to float
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
attention = pp.Forward()
quoted_fragment = pp.Forward()
parenthesized_fragment = pp.Forward()
cross_attention_substitute = pp.Forward()
def make_text_fragment(x):
#print("### making fragment for", x)
if type(x[0]) is Fragment:
assert(False)
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
#print(f'converting {type(x).__name__} to Fragment')
return Fragment(' '.join([s for s in x]))
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
def build_escaped_word_parser(escaped_chars_to_ignore: str):
terms = []
for c in escaped_chars_to_ignore:
terms.append(pp.Literal('\\'+c))
terms.append(
#pp.CharsNotIn(string.whitespace + escaped_chars_to_ignore, exact=1)
pp.Word(pp.printables, exclude_chars=string.whitespace + escaped_chars_to_ignore)
)
return pp.Combine(pp.OneOrMore(
pp.MatchFirst(terms)
))
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
escapes = []
for c in escaped_chars_to_ignore:
escapes.append(pp.Literal('\\'+c))
return pp.Combine(pp.OneOrMore(
pp.MatchFirst(escapes + [pp.CharsNotIn(
string.whitespace + escaped_chars_to_ignore,
exact=1
)])
))
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
#print(f"parsing fragment string for {x}")
fragment_string = x[0]
#print(f"ppparsing fragment string \"{fragment_string}\"")
if len(fragment_string.strip()) == 0:
return Fragment('')
if in_quotes:
# escape unescaped quotes
fragment_string = fragment_string.replace('"', '\\"')
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
try:
result = pp.Group(pp.MatchFirst([
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
#print("parsed to", result)
return result
except pp.ParseException as e:
#print("parse_fragment_str couldn't parse prompt string:", e)
raise
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
empty = (
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
def not_ends_with_swap(x):
#print("trying to match:", x)
return not x[0].endswith('.swap')
unquoted_word = pp.Combine(pp.OneOrMore(
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
(pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') + (pp.NotAny(pp.Word('+') | pp.Word('-'))))
))
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
parenthesized_fragment << (lparen +
pp.Or([
(parenthesized_fragment),
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
(pp.Combine(pp.OneOrMore(
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
pp.Word(string.whitespace)
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
pp.Empty()
]) + rparen)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
debug_attention = False
# attention control of the form (phrase)+ / (phrase)- / (phrase)<weight>
# phrase can be multiple words and can carry multiple +/- signs to strengthen or weaken the effect, or an explicit floating point or integer weight
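# e.g. (illustrative): '(big eyes)++' raises the weight of 'big eyes' to attention_plus_base**2,
# '(small nose)-' lowers it to attention_minus_base, and '(blue sky)1.3' sets the weight to 1.3 directly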
attention_with_parens = pp.Forward()
attention_without_parens = pp.Forward()
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_foot")\
.set_debug(False)
attention_with_parens <<= pp.Group(
lparen +
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
)
+ rparen + attention_with_parens_foot)
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
attention_without_parens_foot = pp.NotAny(pp.White()) + pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots')
attention_without_parens <<= pp.Group(pp.MatchFirst([
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
+ attention_without_parens_foot#.leave_whitespace()
]))
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
attention << pp.MatchFirst([attention_with_parens,
attention_without_parens
])
attention.set_name('attention')
def make_attention(x):
#print("entered make_attention with", x)
children = x[0][:-1]
weight_raw = x[0][-1]
weight = 1.0
if type(weight_raw) is float or type(weight_raw) is int:
weight = weight_raw
elif type(weight_raw) is str:
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
weight = pow(base, len(weight_raw))
#print("making Attention from", children, "with weight", weight)
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
attention_with_parens.set_parse_action(make_attention)
attention_without_parens.set_parse_action(make_attention)
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
# cross-attention control
empty_string = ((lparen + rparen) |
pp.Literal('""').suppress() |
(lparen + pp.Literal('""').suppress() + rparen)
).set_parse_action(lambda x: Fragment(""))
empty_string.set_name('empty_string')
# cross attention control
debug_cross_attention_control = False
original_fragment = pp.MatchFirst([
quoted_fragment.set_debug(debug_cross_attention_control),
parenthesized_fragment.set_debug(debug_cross_attention_control),
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
empty_string.set_debug(debug_cross_attention_control),
])
# support keyword=number arguments
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
edited_fragment = pp.MatchFirst([
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
lparen +
(quoted_fragment | attention |
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
) +
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
rparen,
parenthesized_fragment
])
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
def make_cross_attention_substitute(x):
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
#if len(x>2):
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
#print("made", cacs)
return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# root prompt definition
debug_root_prompt = False
prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt),
attention.set_debug(debug_root_prompt),
quoted_fragment.set_debug(debug_root_prompt),
parenthesized_fragment.set_debug(debug_root_prompt),
unquoted_word.set_debug(debug_root_prompt),
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
) + pp.StringEnd()) \
.set_name('prompt') \
.set_parse_action(lambda x: Prompt(x)) \
.set_debug(debug_root_prompt)
#print("parsing test:", prompt.parse_string("spaced eyes--"))
#print("parsing test:", prompt.parse_string("eyes--"))
# weighted blend of prompts
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
# int weights.
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
def make_prompt_from_quoted_string(x):
#print(' got quoted prompt', x)
x_unquoted = x[0][1:-1]
if len(x_unquoted.strip()) == 0:
# print(' b : just an empty string')
return Prompt([Fragment('')])
#print(f' b parsing \'{x_unquoted}\'')
x_parsed = prompt.parse_string(x_unquoted)
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
return x_parsed[0]
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
quoted_prompt.set_name('quoted_prompt')
debug_blend=False
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ pp.Literal(".blend").suppress()
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
blend.set_debug(debug_blend)
def make_blend(x):
prompts = x[0][0]
weights = x[0][1]
normalize = True
if weights[-1] == 'no_normalize':
normalize = False
weights = weights[:-1]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
blend.set_parse_action(make_blend)
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
+ pp.Literal(".and").suppress()
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
def make_conjunction(x):
parts_raw = x[0][0]
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
parts = [part for part in parts_raw]
return Conjunction(parts, weights)
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
conjunction.set_debug(False)
# top-level is a conjunction of one or more blends or prompts
return conjunction, prompt
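For reference, a few prompt strings that the grammar above is intended to accept (an illustrative sketch, not an exhaustive list):
# 'a (white cat)1.2 on a car'                       -> attention-weighted fragment
# 'a cat.swap(dog) sitting on a car'                -> cross-attention ('prompt2prompt') substitute
# 'a cat.swap(dog, s_end=0.3) sitting on a car'     -> substitute with a per-edit option
# '("a cat", "a dog").blend(0.7, 0.3)'              -> weighted Blend of two prompts
# '("a forest", "an oil painting").and(1.0, 0.5)'   -> explicit Conjunction with weights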
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
Legacy blend parsing.
Grabs all text up to the first occurrence of ':', uses the grabbed text as a sub-prompt, and takes the value
following ':' as its weight. If ':' has no value defined, the weight defaults to 1.0. Repeats until no text remains.
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
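# example (approximate): split_weighted_subprompts('a cat:2 a dog:1') returns [('a cat', 0.667), ('a dog', 0.333)]
# after normalization, or [('a cat', 2.0), ('a dog', 1.0)] when skip_normalize=True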
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
if not log:
return
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)
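# example (sketch; 'model' is assumed to be a loaded model exposing cond_stage_model and its tokenizer):
# log_tokenization('a white cat sitting on a car', model, log=True, weight=1.0)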


@ -89,6 +89,9 @@ class Outcrop(object):
def _extend(self,image:Image,pixels:int)-> Image: def _extend(self,image:Image,pixels:int)-> Image:
extended_img = Image.new('RGBA',(image.width,image.height+pixels)) extended_img = Image.new('RGBA',(image.width,image.height+pixels))
mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \
else pixels *2
# first paste places old image at top of extended image, stretch # first paste places old image at top of extended image, stretch
# it, and applies a gaussian blur to it # it, and applies a gaussian blur to it
# take the top half region, stretch and paste it # take the top half region, stretch and paste it
@ -105,7 +108,9 @@ class Outcrop(object):
# now make the top part transparent to use as a mask # now make the top part transparent to use as a mask
alpha = extended_img.getchannel('A') alpha = extended_img.getchannel('A')
alpha.paste(0,(0,0,extended_img.width,pixels*2)) alpha.paste(0,(0,0,extended_img.width,mask_height))
extended_img.putalpha(alpha) extended_img.putalpha(alpha)
extended_img.save('outputs/curly_extended.png')
return extended_img return extended_img


@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
self.use_ema = use_ema self.use_ema = use_ema
if self.use_ema: if self.use_ema:
self.model_ema = LitEma(self) self.model_ema = LitEma(self)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)


@ -0,0 +1,238 @@
from enum import Enum
import torch
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl:
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
"""
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
if edited_conditioning is not None:
assert len(edit_opcodes) == len(edit_options), \
"there must be 1 edit_options dict for each edit_opcodes tuple"
non_none_edit_options = [x for x in edit_options if x is not None]
assert len(non_none_edit_options)>0, "missing edit_options"
if len(non_none_edit_options)>1:
print('warning: cross-attention control options are not working properly for >1 edit')
self.edit_options = non_none_edit_options[0]
class Context:
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
"""
:param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
"""
self.arguments = arguments
self.step_count = step_count
@classmethod
def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model)
@classmethod
def setup_cross_attention_control(cls, model,
cross_attention_control_args: Arguments
):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementations
:return: None
"""
# adapted from init_attention_edit
device = cross_attention_control_args.edited_conditioning.device
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
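# illustrative example (token positions only, ignoring BOS/EOS and padding): for original 'a cat' and
# edited 'a dog', difflib-style opcodes would be [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2)];
# only the 'equal' span copies indices and sets mask=1, so attention for 'a' is reused while 'dog' gets fresh attention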
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
m.last_attn_slice_mask = None
m.last_attn_slice_indices = None
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
cls.inject_attention_function(model)
class CrossAttentionType(Enum):
SELF = 1
TOKENS = 2
@classmethod
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']:
"""
Which cross-attention control types, if any, should be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
opts = context.arguments.edit_options
to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
to_control.append(cls.CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
to_control.append(cls.CrossAttentionType.TOKENS)
return to_control
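# worked example using the default options from the prompt parser (s_start=0.0, s_end~=0.206, t_start=0.0, t_end=1.0):
# at percent_through=0.1 both SELF and TOKENS control are active; at percent_through=0.5 only TOKENS remains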
@classmethod
def get_attention_modules(cls, model, which: CrossAttentionType):
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
return [module for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod
def clear_requests(cls, model):
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules:
m.save_last_attn_slice = False
m.use_last_attn_slice = False
@classmethod
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
# clear out the saved slice in case the outermost dim changes
m.last_attn_slice = None
m.save_last_attn_slice = True
@classmethod
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
m.use_last_attn_slice = True
@classmethod
def inject_attention_function(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
attn_slice = suggested_attention_slice
if dim is not None:
start = offset
end = start+slice_size
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
#else:
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
if self.use_last_attn_slice:
this_attn_slice = attn_slice
if self.last_attn_slice_mask is not None:
# indices and mask operate on dim=2, no need to slice
base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
base_attn_slice_mask = self.last_attn_slice_mask
if dim is None:
base_attn_slice = base_attn_slice_full
#print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
elif dim == 0:
base_attn_slice = base_attn_slice_full[start:end]
#print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
elif dim == 1:
base_attn_slice = base_attn_slice_full[:, start:end]
#print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
base_attn_slice * base_attn_slice_mask
else:
if dim is None:
attn_slice = self.last_attn_slice
#print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
elif dim == 0:
attn_slice = self.last_attn_slice[start:end]
#print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
elif dim == 1:
attn_slice = self.last_attn_slice[:, start:end]
#print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
if self.save_last_attn_slice:
if dim is None:
self.last_attn_slice = attn_slice
elif dim == 0:
# dynamically grow last_attn_slice if needed
if self.last_attn_slice is None:
self.last_attn_slice = attn_slice
#print("no last_attn_slice: shape now", self.last_attn_slice.shape)
elif self.last_attn_slice.shape[0] == start:
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
assert(self.last_attn_slice.shape[0] == end)
#print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
else:
# no need to grow
self.last_attn_slice[start:end] = attn_slice
#print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
elif dim == 1:
# dynamically grow last_attn_slice if needed
if self.last_attn_slice is None:
self.last_attn_slice = attn_slice
elif self.last_attn_slice.shape[1] == start:
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
assert(self.last_attn_slice.shape[1] == end)
else:
# no need to grow
self.last_attn_slice[:, start:end] = attn_slice
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
if dim is None:
weights = self.last_attn_slice_weights
elif dim == 0:
weights = self.last_attn_slice_weights[start:end]
elif dim == 1:
weights = self.last_attn_slice_weights[:, start:end]
attn_slice = attn_slice * weights
return attn_slice
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.last_attn_slice = None
module.last_attn_slice_indices = None
module.last_attn_slice_mask = None
module.use_last_attn_weights = False
module.use_last_attn_slice = False
module.save_last_attn_slice = False
module.set_attention_slice_wrangler(attention_slice_wrangler)
@classmethod
def remove_attention_function(cls, unet):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None)
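A rough end-to-end sketch of how these hooks are intended to be driven (the tensor shapes follow the Arguments docstring above; 'unet' stands in for a loaded diffusion model, and in InvokeAI this wiring is normally done by InvokeAIDiffuserComponent rather than by hand):
import torch

# hypothetical edited conditioning plus a single all-'equal' opcode covering the 77-token window
args = CrossAttentionControl.Arguments(
    edited_conditioning=torch.randn(1, 77, 768),
    edit_opcodes=[('equal', 0, 77, 0, 77)],
    edit_options=[{'s_start': 0.0, 's_end': 0.2, 't_start': 0.0, 't_end': 1.0}],
)
CrossAttentionControl.setup_cross_attention_control(unet, args)  # 'unet' is a placeholder model object
# ... run the original-prompt pass with request_save_attention_maps() and the edited-prompt pass with
# request_apply_saved_attention_maps(), clearing requests between steps with clear_requests() ...
CrossAttentionControl.remove_cross_attention_control(unet)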


@ -1,10 +1,7 @@
"""SAMPLING ONLY.""" """SAMPLING ONLY."""
import torch import torch
import numpy as np from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like from ldm.modules.diffusionmodules.util import noise_like
@ -12,6 +9,21 @@ class DDIMSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs): def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps,device) super().__init__(model,schedule,model.num_timesteps,device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.remove_cross_attention_control()
# This is the central routine # This is the central routine
@torch.no_grad() @torch.no_grad()
def p_sample( def p_sample(
@ -29,6 +41,7 @@ class DDIMSampler(Sampler):
corrector_kwargs=None, corrector_kwargs=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
step_count:int=1000, # total number of steps
**kwargs, **kwargs,
): ):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
@ -37,16 +50,17 @@ class DDIMSampler(Sampler):
unconditional_conditioning is None unconditional_conditioning is None
or unconditional_guidance_scale == 1.0 or unconditional_guidance_scale == 1.0
): ):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
x_in = torch.cat([x] * 2) # step_index counts in the opposite direction to index
t_in = torch.cat([t] * 2) step_index = step_count-(index+1)
c_in = torch.cat([unconditional_conditioning, c]) e_t = self.invokeai_diffuser.do_diffusion_step(
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) x, t,
e_t = e_t_uncond + unconditional_guidance_scale * ( unconditional_conditioning, c,
e_t - e_t_uncond unconditional_guidance_scale,
step_index=step_index
) )
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == 'eps' assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score( e_t = score_corrector.modify_score(


@ -19,6 +19,7 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from torchvision.utils import make_grid from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only
from omegaconf import ListConfig
import urllib import urllib
from ldm.util import ( from ldm.util import (
@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
self.use_ema = use_ema self.use_ema = use_ema
if self.use_ema: if self.use_ema:
self.model_ema = LitEma(self.model) self.model_ema = LitEma(self.model)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
self.use_scheduler = scheduler_config is not None self.use_scheduler = scheduler_config is not None
if self.use_scheduler: if self.use_scheduler:
@ -820,21 +821,21 @@ class LatentDiffusion(DDPM):
) )
return self.scale_factor * z return self.scale_factor * z
def get_learned_conditioning(self, c): def get_learned_conditioning(self, c, **kwargs):
if self.cond_stage_forward is None: if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable( if hasattr(self.cond_stage_model, 'encode') and callable(
self.cond_stage_model.encode self.cond_stage_model.encode
): ):
c = self.cond_stage_model.encode( c = self.cond_stage_model.encode(
c, embedding_manager=self.embedding_manager c, embedding_manager=self.embedding_manager,**kwargs
) )
if isinstance(c, DiagonalGaussianDistribution): if isinstance(c, DiagonalGaussianDistribution):
c = c.mode() c = c.mode()
else: else:
c = self.cond_stage_model(c) c = self.cond_stage_model(c, **kwargs)
else: else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward) assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
return c return c
def meshgrid(self, h, w): def meshgrid(self, h, w):
@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
return samples, intermediates return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
@torch.no_grad() @torch.no_grad()
def log_images( def log_images(
self, self,
@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc) out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid': elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t, context=cc) out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'adm': elif self.conditioning_key == 'adm':
cc = c_crossattn[0] cc = c_crossattn[0]
@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
cond_img = torch.stack(bbox_imgs, dim=0) cond_img = torch.stack(bbox_imgs, dim=0)
logs['bbox_image'] = cond_img logs['bbox_image'] = cond_img
return logs return logs
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
finetune_keys=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
@torch.no_grad()
def get_input(
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
# note: restricted to non-trainable encoders currently
assert (
not self.cond_stage_trainable
), "trainable cond stages not yet supported for inpainting"
z, c, x, xrec, xc = super().get_input(
batch,
self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=bs,
)
assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
cc = (
rearrange(batch[ck], "b h w c -> b c h w")
.to(memory_format=torch.contiguous_format)
.float()
)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
bchw = z.shape
if ck != self.masked_image_key:
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds


@ -1,16 +1,16 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K import k_diffusion as K
import torch import torch
import torch.nn as nn from torch import nn
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler from .sampler import Sampler
from ldm.util import rand_perlin_2d from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps, # at this threshold, the scheduler will stop using the Karras
noise_like, # noise schedule and start using the model's schedule
extract_into_tensor, STEP_THRESHOLD = 29
)
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0: if threshold <= 0.0:
@ -33,12 +33,21 @@ class CFGDenoiser(nn.Module):
self.threshold = threshold self.threshold = threshold
self.warmup_max = warmup self.warmup_max = warmup
self.warmup = max(warmup / 10, 1) self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.remove_cross_attention_control()
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2) next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
if self.warmup < self.warmup_max: if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1 self.warmup += 1
@ -46,8 +55,7 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold thresh = self.threshold
if thresh > self.threshold: if thresh > self.threshold:
thresh = self.threshold thresh = self.threshold
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh) return cfg_apply_threshold(next_x, thresh)
class KSampler(Sampler): class KSampler(Sampler):
def __init__(self, model, schedule='lms', device=None, **kwargs): def __init__(self, model, schedule='lms', device=None, **kwargs):
@ -60,16 +68,9 @@ class KSampler(Sampler):
self.sigmas = None self.sigmas = None
self.ds = None self.ds = None
self.s_in = None self.s_in = None
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
def forward(self, x, sigma, uncond, cond, cond_scale): if self.karras_max is None:
x_in = torch.cat([x] * 2) self.karras_max = STEP_THRESHOLD
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
def make_schedule( def make_schedule(
self, self,
@ -98,8 +99,13 @@ class KSampler(Sampler):
rho=7., rho=7.,
device=self.device, device=self.device,
) )
if ddim_num_steps >= self.karras_max:
print(f'>> Ksampler using model noise schedule (steps > {self.karras_max})')
self.sigmas = self.model_sigmas self.sigmas = self.model_sigmas
#self.sigmas = self.karras_sigmas else:
print(f'>> Ksampler using karras noise schedule (steps <= {self.karras_max})')
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which # ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to # means that inpainting will not work. To get this to work we need to be able to
@ -118,6 +124,7 @@ class KSampler(Sampler):
use_original_steps=False, use_original_steps=False,
init_latent = None, init_latent = None,
mask = None, mask = None,
**kwargs
): ):
samples,_ = self.sample( samples,_ = self.sample(
batch_size = 1, batch_size = 1,
@ -129,7 +136,8 @@ class KSampler(Sampler):
unconditional_conditioning = unconditional_conditioning, unconditional_conditioning = unconditional_conditioning,
img_callback = img_callback, img_callback = img_callback,
x0 = init_latent, x0 = init_latent,
mask = mask mask = mask,
**kwargs
) )
return samples return samples
@ -163,6 +171,7 @@ class KSampler(Sampler):
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
extra_conditioning_info=None,
threshold = 0, threshold = 0,
perlin = 0, perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -181,7 +190,6 @@ class KSampler(Sampler):
) )
# sigmas are set up in make_schedule - we take the last steps items # sigmas are set up in make_schedule - we take the last steps items
total_steps = len(self.sigmas)
sigmas = self.sigmas[-S-1:] sigmas = self.sigmas[-S-1:]
# x_T is variation noise. When an init image is provided (in x0) we need to add # x_T is variation noise. When an init image is provided (in x0) we need to add
@ -195,19 +203,21 @@ class KSampler(Sampler):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
extra_args = { extra_args = {
'cond': conditioning, 'cond': conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale, 'cond_scale': unconditional_guidance_scale,
} }
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)') print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
return ( sampling_result = (
K.sampling.__dict__[f'sample_{self.schedule}']( K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args, model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback callback=route_callback
), ),
None, None,
) )
return sampling_result
# this code will support inpainting if and when ksampler API modified or # this code will support inpainting if and when ksampler API modified or
# a workaround is found. # a workaround is found.
@ -220,6 +230,7 @@ class KSampler(Sampler):
index, index,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
extra_conditioning_info=None,
**kwargs, **kwargs,
): ):
if self.model_wrap is None: if self.model_wrap is None:
@ -245,6 +256,7 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas: # so the actual formula for indexing into sigmas:
# sigma_index = (steps-index) # sigma_index = (steps-index)
s_index = t_enc - index - 1 s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
img = K.sampling.__dict__[f'_{self.schedule}']( img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap, self.model_wrap,
img, img,
@ -269,7 +281,7 @@ class KSampler(Sampler):
else: else:
return x return x
def prepare_to_sample(self,t_enc): def prepare_to_sample(self,t_enc,**kwargs):
self.t_enc = t_enc self.t_enc = t_enc
self.model_wrap = None self.model_wrap = None
self.ds = None self.ds = None
@ -281,3 +293,6 @@ class KSampler(Sampler):
''' '''
return self.model.inner_model.q_sample(x0,ts) return self.model.inner_model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.inner_model.model.conditioning_key


@ -5,6 +5,7 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.invoke.devices import choose_torch_device from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like from ldm.modules.diffusionmodules.util import noise_like
@ -13,6 +14,18 @@ class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs): def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device) super().__init__(model,schedule,model.num_timesteps, device)
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.remove_cross_attention_control()
# this is the essential routine # this is the essential routine
@torch.no_grad() @torch.no_grad()
def p_sample( def p_sample(
@ -32,6 +45,7 @@ class PLMSSampler(Sampler):
unconditional_conditioning=None, unconditional_conditioning=None,
old_eps=[], old_eps=[],
t_next=None, t_next=None,
step_count:int=1000, # total number of steps
**kwargs, **kwargs,
): ):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
@ -41,18 +55,15 @@ class PLMSSampler(Sampler):
unconditional_conditioning is None unconditional_conditioning is None
or unconditional_guidance_scale == 1.0 or unconditional_guidance_scale == 1.0
): ):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
x_in = torch.cat([x] * 2) # step_index counts in the opposite direction to index
t_in = torch.cat([t] * 2) step_index = step_count-(index+1)
c_in = torch.cat([unconditional_conditioning, c]) e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
e_t_uncond, e_t = self.model.apply_model( unconditional_conditioning, c,
x_in, t_in, c_in unconditional_guidance_scale,
).chunk(2) step_index=step_index)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == 'eps' assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score( e_t = score_corrector.modify_score(


@ -2,13 +2,13 @@
ldm.models.diffusion.sampler ldm.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
''' '''
import torch import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.invoke.devices import choose_torch_device from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.diffusionmodules.util import ( from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters, make_ddim_sampling_parameters,
@ -24,6 +24,8 @@ class Sampler(object):
self.ddpm_num_timesteps = steps self.ddpm_num_timesteps = steps
self.schedule = schedule self.schedule = schedule
self.device = device or choose_torch_device() self.device = device or choose_torch_device()
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
@ -158,6 +160,18 @@ class Sampler(object):
**kwargs, **kwargs,
): ):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# check to see if make_schedule() has run, and if not, run it # check to see if make_schedule() has run, and if not, run it
if self.ddim_timesteps is None: if self.ddim_timesteps is None:
self.make_schedule( self.make_schedule(
@ -190,10 +204,11 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
steps=S, steps=S,
**kwargs
) )
return samples, intermediates return samples, intermediates
#torch.no_grad() @torch.no_grad()
def do_sampling( def do_sampling(
self, self,
cond, cond,
@ -214,6 +229,7 @@ class Sampler(object):
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
steps=None, steps=None,
**kwargs
): ):
b = shape[0] b = shape[0]
time_range = ( time_range = (
@ -231,7 +247,7 @@ class Sampler(object):
dynamic_ncols=True, dynamic_ncols=True,
) )
old_eps = [] old_eps = []
self.prepare_to_sample(t_enc=total_steps) self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
img = self.get_initial_image(x_T,shape,total_steps) img = self.get_initial_image(x_T,shape,total_steps)
# probably don't need this at all # probably don't need this at all
@ -274,6 +290,7 @@ class Sampler(object):
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
step_count=steps
)
img, pred_x0, e_t = outs
@ -305,8 +322,9 @@ class Sampler(object):
use_original_steps=False,
init_latent = None,
mask = None,
all_timesteps_count = None,
**kwargs
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
@ -321,7 +339,7 @@ class Sampler(object):
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
-self.prepare_to_sample(t_enc=total_steps)
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@ -353,6 +371,7 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
t_next = ts_next,
step_count=len(self.ddim_timesteps)
)
x_dec, pred_x0, e_t = outs
@ -411,3 +430,21 @@ class Sampler(object):
return self.model.inner_model.q_sample(x0,ts)
'''
return self.model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.model.conditioning_key
def uses_inpainting_model(self)->bool:
return self.conditioning_key() in ('hybrid','concat')
def adjust_settings(self,**kwargs):
'''
This is a catch-all method for adjusting any instance variables
after the sampler is instantiated. No type-checking performed
here, so use with care!
'''
for k in kwargs.keys():
try:
setattr(self,k,kwargs[k])
except AttributeError:
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')
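The helper methods added above are what generator code uses to detect whether the loaded checkpoint is an inpainting model and to tweak sampler attributes after construction. A minimal, hypothetical usage sketch (the sampler object and the attribute name passed to adjust_settings() are illustrative only, not part of this diff):
def describe_sampler(sampler):
    # conditioning_key() surfaces model.model.conditioning_key; 'hybrid' or 'concat'
    # means the checkpoint expects a mask/latent concat, i.e. an inpainting model
    if sampler.uses_inpainting_model():
        print('>> inpainting-style conditioning:', sampler.conditioning_key())
    else:
        print('>> standard cross-attention conditioning')
    # adjust_settings() simply setattr()s whatever keyword arguments it is given,
    # with no type checking -- 'karras_max' here is only an example attribute name
    sampler.adjust_settings(karras_max=29)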

View File

@ -0,0 +1,216 @@
from math import ceil
from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
class InvokeAIDiffuserComponent:
'''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
At the moment it includes the following features:
* Cross attention control ("prompt2prompt")
* Hybrid conditioning (used for inpainting)
'''
class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
self.cross_attention_control_args = cross_attention_control_args
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
self.model = model
self.model_forward_callback = model_forward_callback
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning
self.cross_attention_control_context = CrossAttentionControl.Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
CrossAttentionControl.setup_cross_attention_control(self.model,
cross_attention_control_args=self.conditioning.cross_attention_control_args
)
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count
def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],
conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float,
step_index: Optional[int]=None
):
"""
:param x: current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase. If None, will be estimated by comparing sigma against self.model.sigmas.
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
CrossAttentionControl.clear_requests(self.model)
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
elif wants_cross_attention_control:
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
else:
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
# to scale how much effect conditioning has, calculate the changes it does and then scale that
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
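The last three lines above are the whole of classifier-free guidance: only the *difference* between the conditioned and unconditioned predictions is scaled. A toy, model-free sketch of that arithmetic:
import torch
# stand-ins for the two outputs produced by the conditioning branches above
unconditioned_next_x = torch.tensor([0.10, 0.20, 0.30])
conditioned_next_x   = torch.tensor([0.40, 0.10, 0.30])
unconditional_guidance_scale = 7.5   # aka CFG scale
# scale only the change that conditioning makes, then add it back
scaled_delta    = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
print(combined_next_x)   # tensor([ 2.3500, -0.5500,  0.3000])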
# methods below are called from do_diffusion_step and should be considered private to this class.
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
both_conditionings).chunk(2)
return unconditioned_next_x, conditioned_next_x
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = dict()
for k in conditioning:
if isinstance(conditioning[k], list):
both_conditionings[k] = [
torch.cat([unconditioning[k][i], conditioning[k][i]])
for i in range(len(conditioning[k]))
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
# This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_save_attention_maps(self.model, type)
_ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
CrossAttentionControl.clear_requests(self.model)
return unconditioned_next_x, conditioned_next_x
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
# flip because sigmas[0] is for the fully denoised image
# percent_through must be <1
return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
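When no step_index is handed in, the fallback above looks the current sigma up in the sampler's noise schedule (ascending, with sigmas[0] belonging to the fully denoised image, as the comment notes). A toy example of that lookup, with made-up sigma values:
import torch
sigmas = torch.tensor([0.03, 1.1, 3.5, 7.9, 14.6])   # hypothetical ascending schedule
sigma  = torch.tensor(3.5)                            # current noise level
sigma_index = torch.nonzero(sigmas <= sigma)[-1]      # best matching index: 2
percent_through = 1.0 - float(sigma_index.item() + 1) / float(sigmas.shape[0])
print(percent_through)                                # 0.4 -- roughly 40% of the way through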
# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
# below is fugly omg
num_actual_conditionings = len(c_or_weighted_c_list)
conditionings = [uc] + [c for c,weight in weighted_cond_list]
weights = [1] + [weight for c,weight in weighted_cond_list]
chunk_count = ceil(len(conditionings)/2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index*2
chunk_size = min(2, len(conditionings)-offset)
if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset:offset+2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
if chunk_index == 0:
uncond_latents = latents_a
deltas = latents_b - uncond_latents
else:
deltas = torch.cat((deltas, latents_a - uncond_latents))
if latents_b is not None:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale
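The tail of apply_conjunction is just a broadcasted weighted sum over the per-conditioning deltas. A self-contained sketch of that reduction, with shapes shrunk to tiny 2x2 "latents" for readability (values are arbitrary):
import torch
deltas  = torch.tensor([[[1.0, 0.0], [0.0, 0.0]],
                        [[0.0, 2.0], [0.0, 0.0]],
                        [[0.0, 0.0], [3.0, 0.0]]])   # one delta per conditioning
weights = [1.0, 0.5, 2.0]                            # the .and(...) weights
per_delta_weights = torch.tensor(weights, dtype=deltas.dtype)
reshaped_weights  = per_delta_weights.reshape(per_delta_weights.shape + (1, 1))
deltas_merged     = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
print(deltas_merged)   # tensor([[[1., 1.], [6., 0.]]])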

View File

@ -1,5 +1,7 @@
from inspect import isfunction
import math
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn, einsum
@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module):
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
@ -170,46 +173,71 @@ class CrossAttention(nn.Module):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
-def einsum_op_compvis(self, q, k, v):
-s = einsum('b i d, b j d -> b i j', q, k)
-s = s.softmax(dim=-1, dtype=s.dtype)
-return einsum('b i j, b j d -> b i d', s, v)
-def einsum_op_slice_0(self, q, k, v, slice_size):
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
self is the current CrossAttention module for which the callback is being invoked.
attention_scores are the scores for attention
suggested_attention_slice is a softmax(dim=-1) over attention_scores
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If dim is >= 0, offset and slice_size specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.attention_slice_wrangler is not None:
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
return einsum('b i j, b j d -> b i d', attention_slice, v)
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
-r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
-def einsum_op_slice_1(self, q, k, v, slice_size):
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
-r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-return self.einsum_op_slice_1(q, k, v, slice_size)
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
-return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
-return self.einsum_op_slice_0(q, k, v, 1)
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
-return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
-return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
-return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
@ -221,7 +249,7 @@ class CrossAttention(nn.Module):
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-def einsum_op(self, q, k, v):
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v)
@ -244,8 +272,13 @@ class CrossAttention(nn.Module):
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-r = self.einsum_op(q, k, v)
-return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
r = self.get_attention_mem_efficient(q, k, v)
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states)
class BasicTransformerBlock(nn.Module):
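set_attention_slice_wrangler() is the hook the new cross-attention-control code installs itself through, but any callable with the documented signature will do. A minimal, hypothetical wrangler that merely sharpens the suggested attention slice and re-normalizes it (the wiring at the bottom assumes you already hold references to the model's CrossAttention modules, which is not shown in this diff):
import torch
from torch import nn
def sharpening_wrangler(module: nn.Module,
                        attention_scores: torch.Tensor,
                        suggested_attention_slice: torch.Tensor,
                        dim, offset, slice_size) -> torch.Tensor:
    # suggested_attention_slice is already softmax(dim=-1)-normalized by
    # einsum_lowest_level(); dim is None/-1 for the unsliced path, 0 or 1 when slicing
    sharpened = suggested_attention_slice ** 1.1
    return sharpened / sharpened.sum(dim=-1, keepdim=True)
# hypothetical wiring -- cross_attention_modules is assumed to be a list of the
# CrossAttention instances inside the UNet, gathered elsewhere:
# for module in cross_attention_modules:
#     module.set_attention_slice_wrangler(sharpening_wrangler)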

View File

@ -1,3 +1,5 @@
import math
import torch
import torch.nn as nn
from functools import partial
@ -454,6 +456,223 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text, **kwargs):
return self(text, **kwargs)
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
fragment_weights_key = "fragment_weights"
return_tokens_key = "return_tokens"
def forward(self, text: list, **kwargs):
'''
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
weights shall be applied.
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
'''
if self.fragment_weights_key not in kwargs:
# fallback to base class implementation
return super().forward(text, **kwargs)
fragment_weights = kwargs[self.fragment_weights_key]
# self.transformer doesn't like receiving "fragment_weights" as an argument
kwargs.pop(self.fragment_weights_key)
should_return_tokens = False
if self.return_tokens_key in kwargs:
should_return_tokens = kwargs.get(self.return_tokens_key, False)
# self.transformer doesn't like having extra kwargs
kwargs.pop(self.return_tokens_key)
batch_z = None
batch_tokens = None
for fragments, weights in zip(text, fragment_weights):
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
# applying a multiplier to the CFG scale on a per-token basis).
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
# "red" is to tell SD that it should almost completely *ignore* redness).
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
# handle weights >=1
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
# this is our starting point
embeddings = base_embedding.unsqueeze(0)
per_embedding_weights = [1.0]
# now handle weights <1
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
# removed.
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
for index, fragment_weight in enumerate(weights):
if fragment_weight < 1:
fragments_without_this = fragments[:index] + fragments[index+1:]
weights_without_this = weights[:index] + weights[index+1:]
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
# therefore:
# fragment_weight = 1: we are at base_z => lerp weight 0
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
# so let's use tan(), because:
# tan is 0.0 at 0,
# 1.0 at PI/4, and
# inf at PI/2
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
epsilon = 1e-9
fragment_weight = max(epsilon, fragment_weight) # inf is bad
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
# todo handle negative weight?
per_embedding_weights.append(embedding_lerp_weight)
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
# append to batch
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
# should have shape (B, 77, 768)
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
if should_return_tokens:
return batch_z, batch_tokens
else:
return batch_z
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
tokens = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=False,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
if include_start_and_end_markers:
return tokens
else:
return [x[1:-1] for x in tokens]
@classmethod
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
if normalize:
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
return torch.sum(embeddings * reshaped_weights, dim=1)
# lerped embeddings has shape (77, 768)
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
'''
:param fragments:
:param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
:return:
'''
# empty is meaningful
if len(fragments) == 0 and len(weights) == 0:
fragments = ['']
weights = [1]
item_encodings = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=True,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
all_tokens = []
per_token_weights = []
#print("all fragments:", fragments, weights)
for index, fragment in enumerate(item_encodings):
weight = weights[index]
#print("processing fragment", fragment, weight)
fragment_tokens = item_encodings[index]
#print("fragment", fragment, "processed to", fragment_tokens)
# trim bos and eos markers before appending
all_tokens.extend(fragment_tokens[1:-1])
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
if (len(all_tokens) + 2) > self.max_length:
excess_token_count = (len(all_tokens) + 2) - self.max_length
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
all_tokens = all_tokens[:self.max_length - 2]
per_token_weights = per_token_weights[:self.max_length - 2]
# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
# (77 = self.max_length)
pad_length = self.max_length - 1 - len(all_tokens)
all_tokens.insert(0, self.tokenizer.bos_token_id)
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
per_token_weights.insert(0, 1)
per_token_weights.extend([1] * pad_length)
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
return all_tokens_tensor, per_token_weights_tensor
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
'''
Build a tensor representing the passed-in tokens, each of which has a weight.
:param tokens: A tensor of shape (77) containing token ids (integers)
:param per_token_weights: A tensor of shape (77) containing weights (floats)
:param weight_delta_from_empty: Whether to weight only each token's distance from an "empty" feature vector (True) or to multiply the whole feature vector for each token (False)
:param kwargs: passed on to self.transformer()
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
'''
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
if weight_delta_from_empty:
empty_tokens = self.tokenizer([''] * z.shape[0],
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='pt'
)['input_ids'].to(self.device)
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
weighted_z_delta_from_empty = (weighted_z-empty_z)
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
#print("using empty-delta method, first 5 rows:")
#print(weighted_z[:5])
return weighted_z
else:
original_mean = z.mean()
z *= batch_weights_expanded
after_weighting_mean = z.mean()
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
mean_correction_factor = original_mean/after_weighting_mean
z *= mean_correction_factor
return z
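The tan((1 - weight) * pi/2) lerp factor above is easiest to see with numbers: at weight 0.5 the "with fragment" and "without fragment" embeddings contribute equally once normalized. A toy version of apply_embedding_weights() on tiny tensors (shapes shrunk from (1, 77, 768); the stand-in values are arbitrary):
import math
import torch
def apply_embedding_weights(embeddings, per_embedding_weights, normalize=True):
    # same reduction as WeightedFrozenCLIPEmbedder.apply_embedding_weights, toy shapes
    w = torch.tensor(per_embedding_weights, dtype=embeddings.dtype)
    if normalize:
        w = w / torch.sum(w)
    return torch.sum(embeddings * w.reshape(w.shape + (1, 1)), dim=1)
base_embedding        = torch.full((1, 2, 3), 1.0)   # stand-in for "mountain man"
embedding_without_man = torch.full((1, 2, 3), 3.0)   # stand-in for "mountain"
fragment_weight = 0.5                                          # as in "man:0.5"
lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)  # ~= 1.0
embeddings = torch.cat((base_embedding.unsqueeze(0),
                        embedding_without_man.unsqueeze(0)), dim=1)
print(apply_embedding_weights(embeddings, [1.0, lerp_weight])) # ~2.0 everywhere: the halfway point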
class FrozenCLIPTextEmbedder(nn.Module):
"""

View File

@ -18,6 +18,7 @@ from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from omegaconf import OmegaConf
from pathlib import Path
from pyparsing import ParseException
# global used in multiple functions (fix)
infile = None
@ -172,8 +173,7 @@ def main_loop(gen, opt):
pass
if len(opt.prompt) == 0:
-print('\nTry again with a prompt!')
-continue
opt.prompt = ''
# width and height are set by model if not specified
if not opt.width:
@ -328,12 +328,16 @@ def main_loop(gen, opt):
if operation == 'generate':
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate'
try:
gen.prompt2image(
image_callback=image_writer,
step_callback=step_callback,
catch_interrupts=catch_ctrl_c,
**vars(opt)
)
except ParseException as e:
print('** An error occurred while processing your prompt **')
print(f'** {str(e)} **')
elif operation == 'postprocess':
print(f'>> fixing {opt.prompt}')
opt.last_operation = do_postprocess(gen,opt,image_writer)
@ -592,7 +596,9 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
def do_textmask(gen, opt, callback):
image_path = opt.prompt
-assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
if not os.path.exists(image_path):
image_path = os.path.join(opt.outdir,image_path)
assert os.path.exists(image_path), '** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **'
assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **'
tm = opt.text_mask[0]
threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5

446
tests/test_prompt_parser.py Normal file
View File

@ -0,0 +1,446 @@
import unittest
import pyparsing
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
Fragment
def parse_prompt(prompt_string):
pp = PromptParser()
#print(f"parsing '{prompt_string}'")
parse_result = pp.parse_conjunction(prompt_string)
#print(f"-> parsed '{prompt_string}' to {parse_result}")
return parse_result
def make_basic_conjunction(strings: list[str]):
fragments = [Fragment(x) for x in strings]
return Conjunction([FlattenedPrompt(fragments)])
def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]):
fragments = [Fragment(x, w) for x,w in weighted_strings]
return Conjunction([FlattenedPrompt(fragments)])
class PromptParserTestCase(unittest.TestCase):
def test_empty(self):
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
def test_basic(self):
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
def test_attention(self):
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]),
parse_prompt("(pretty flowers)+"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]),
parse_prompt("(pretty flowers)+, the flames are too hot"))
def test_no_parens_attention_runon(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++"))
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames"))
def test_explicit_conjunction(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
self.assertEqual(
Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
def test_conjunction_weights(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))
with self.assertRaises(PromptParser.ParsingException):
parse_prompt('("fire", "flames").and(2)')
parse_prompt('("fire", "flames").and(2,1,2)')
def test_complex_conjunction(self):
#print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
CrossAttentionControlSubstitute(
[Fragment("bicycle", pow(1.1,2))],
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
def test_badly_formed(self):
def make_untouched_prompt(prompt):
return Conjunction([FlattenedPrompt([(prompt, 1.0)])])
def assert_if_prompt_string_not_untouched(prompt):
self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
assert_if_prompt_string_not_untouched('a test prompt')
# todo handle this
#assert_if_prompt_string_not_untouched('a badly formed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed test prompt')
#with self.assertRaises(pyparsing.ParseException):
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed +test prompt')
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(((a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(a (ba)dly (f)ormed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('("((a badly (formed +test ").blend(1.0)')
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
parse_prompt("hamburger ((bun))"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
parse_prompt("hamburger (bun)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
parse_prompt("hamburger (kaiser roll)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
parse_prompt("hamburger ((kaiser roll))"))
def test_blend(self):
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
)
self.assertEqual(Conjunction([Blend(
[FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
)
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
FlattenedPrompt([('hi', 1.0)])],
weights=[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)")
)
# blending a single entry is not a failure
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
parse_prompt("(\"fire\").blend(0.7)")
)
# blend with empty
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
)
def test_nested(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
parse_prompt('fire (flames (trees)1.5)2.0'))
self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
weights=[1.0, 1.0])]),
parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)'))
def test_cross_attention_control(self):
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
(', fire', 1.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap("")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap(" ")'))
def test_cross_attention_control_with_attention(self):
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(0.9,1))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)"))
def test_cross_attention_control_options(self):
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog"))
self.assertEqual(
Conjunction([
FlattenedPrompt([Fragment('a fantasy forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)],
options={'s_start': 0.8, 't_start': 0.8})])]),
parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)"))
def test_escaping(self):
# make sure ", ( and ) can be escaped
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)'))
self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)'))
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
def test_cross_attention_escaping(self):
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain (man).swap(monkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (man).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain ("man").swap(monkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain ("man").swap("monkey")'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain (\\"man).swap("monkey")'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (man).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
def test_legacy_blend(self):
pp = PromptParser()
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain man:1 man mountain:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man:1 man mountain-'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man: man mountain-:'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.75,0.25]),
pp.parse_legacy_blend('mountain man:3 man mountain:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[1.0,0.0]),
pp.parse_legacy_blend('mountain man:3 man mountain:0'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.8,0.2]),
pp.parse_legacy_blend('"mountain man":4 man mountain'))
def test_single(self):
# todo handle this
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
# parse_prompt('a badly formed +test prompt'))
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
pass
if __name__ == '__main__':
unittest.main()
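Assuming the repository root is on PYTHONPATH, the new suite runs under plain unittest, and the PromptParser entry point it exercises can also be poked at directly; a short usage sketch (the prompt string is arbitrary):
# run everything:      python -m unittest tests.test_prompt_parser
# run a single test:   python -m unittest tests.test_prompt_parser.PromptParserTestCase.test_blend
from ldm.invoke.prompt_parser import PromptParser
pp = PromptParser()
conjunction = pp.parse_conjunction('a "cat".swap(dog) eating a (hotdog)1.3')
print(conjunction)   # a Conjunction wrapping FlattenedPrompts and a CrossAttentionControlSubstitute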