Merge branch 'development' into main
3
.dockerignore
Normal file
@ -0,0 +1,3 @@
|
||||
*
|
||||
!environment*.yml
|
||||
!docker-build
|
43
.github/workflows/build-container.yml
vendored
Normal file
@ -0,0 +1,43 @@
|
||||
# Building the Image without pushing to confirm it is still buildable
|
||||
# confirum functionality would unfortunately need way more resources
|
||||
name: build container image
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
pull_request:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: prepare docker-tag
|
||||
env:
|
||||
repository: ${{ github.repository }}
|
||||
run: echo "dockertag=${repository,,}" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-${{ github.sha }}
|
||||
restore-keys: ${{ runner.os }}-buildx-
|
||||
- name: Build container
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
file: docker-build/Dockerfile
|
||||
platforms: linux/amd64
|
||||
push: false
|
||||
tags: ${{ env.dockertag }}:latest
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache
|
24
.github/workflows/create-caches.yml
vendored
@ -54,27 +54,9 @@ jobs:
|
||||
[[ -d models/ldm/stable-diffusion-v1 ]] \
|
||||
|| mkdir -p models/ldm/stable-diffusion-v1
|
||||
[[ -r models/ldm/stable-diffusion-v1/model.ckpt ]] \
|
||||
|| curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||
|
||||
- name: Use cached Conda Environment
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-conda-env-${{ env.CONDA_ENV_NAME }}
|
||||
conda-env-file: ${{ matrix.environment-file }}
|
||||
with:
|
||||
path: ${{ env.CONDA_ROOT }}/envs/${{ env.CONDA_ENV_NAME }}
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }}
|
||||
|
||||
- name: Use cached Conda Packages
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-conda-env-${{ env.CONDA_ENV_NAME }}
|
||||
conda-env-file: ${{ matrix.environment-file }}
|
||||
with:
|
||||
path: ${{ env.CONDA_PKGS_DIR }}
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }}
|
||||
|| curl --user "${{ secrets.HUGGINGFACE_TOKEN }}" \
|
||||
-o models/ldm/stable-diffusion-v1/model.ckpt \
|
||||
-O -L https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||
|
||||
- name: Activate Conda Env
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
|
32
.github/workflows/test-invoke-conda.yml
vendored
@ -4,8 +4,7 @@ on:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
- 'fix-gh-actions-fork'
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
@ -13,15 +12,18 @@ on:
|
||||
jobs:
|
||||
os_matrix:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
environment-file: environment.yml
|
||||
default-shell: bash -l {0}
|
||||
stable-diffusion-model: https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||
- os: macos-latest
|
||||
environment-file: environment-mac.yml
|
||||
default-shell: bash -l {0}
|
||||
stable-diffusion-model: https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||
name: Test invoke.py on ${{ matrix.os }} with conda
|
||||
runs-on: ${{ matrix.os }}
|
||||
defaults:
|
||||
@ -48,7 +50,7 @@ jobs:
|
||||
|
||||
- name: set test prompt to Pull Request validation
|
||||
if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }}
|
||||
run: echo "TEST_PROMPTS=tests/pr_prompt.txt" >> $GITHUB_ENV
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> $GITHUB_ENV
|
||||
|
||||
- name: set conda environment name
|
||||
run: echo "CONDA_ENV_NAME=invokeai" >> $GITHUB_ENV
|
||||
@ -57,7 +59,7 @@ jobs:
|
||||
id: cache-sd-v1-4
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-sd-v1-4
|
||||
cache-name: cache-sd-${{ matrix.stable-diffusion-model }}
|
||||
with:
|
||||
path: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
key: ${{ env.cache-name }}
|
||||
@ -69,25 +71,9 @@ jobs:
|
||||
[[ -d models/ldm/stable-diffusion-v1 ]] \
|
||||
|| mkdir -p models/ldm/stable-diffusion-v1
|
||||
[[ -r models/ldm/stable-diffusion-v1/model.ckpt ]] \
|
||||
|| curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||
|
||||
- name: Use cached Conda Environment
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-conda-env-${{ env.CONDA_ENV_NAME }}
|
||||
conda-env-file: ${{ matrix.environment-file }}
|
||||
with:
|
||||
path: ${{ env.CONDA }}/envs/${{ env.CONDA_ENV_NAME }}
|
||||
key: env-${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }}
|
||||
|
||||
- name: Use cached Conda Packages
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-conda-pkgs-${{ env.CONDA_ENV_NAME }}
|
||||
conda-env-file: ${{ matrix.environment-file }}
|
||||
with:
|
||||
path: ${{ env.CONDA_PKGS_DIR }}
|
||||
key: pkgs-${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }}
|
||||
|| curl --user "${{ secrets.HUGGINGFACE_TOKEN }}" \
|
||||
-o models/ldm/stable-diffusion-v1/model.ckpt \
|
||||
-O -L ${{ matrix.stable-diffusion-model }}
|
||||
|
||||
- name: Activate Conda Env
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
|
4
.gitignore
vendored
@ -3,6 +3,10 @@ outputs/
|
||||
models/ldm/stable-diffusion-v1/model.ckpt
|
||||
ldm/invoke/restoration/codeformer/weights
|
||||
|
||||
# ignore user models config
|
||||
configs/models.user.yaml
|
||||
config/models.user.yml
|
||||
|
||||
# ignore the Anaconda/Miniconda installer used while building Docker image
|
||||
anaconda.sh
|
||||
|
||||
|
@ -14,7 +14,7 @@ from threading import Event
|
||||
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.invoke.conditioning import split_weighted_subprompts
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
|
||||
from backend.modules.parameters import parameters_to_command
|
||||
|
||||
|
@ -33,7 +33,7 @@ from ldm.generate import Generate
|
||||
from ldm.invoke.restoration import Restoration
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.conditioning import split_weighted_subprompts
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
|
||||
from modules.parameters import parameters_to_command
|
||||
|
||||
|
@ -13,6 +13,13 @@ stable-diffusion-1.4:
|
||||
width: 512
|
||||
height: 512
|
||||
default: true
|
||||
inpainting-1.5:
|
||||
description: runwayML tuned inpainting model v1.5
|
||||
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||
config: configs/stable-diffusion/v1-inpainting-inference.yaml
|
||||
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
width: 512
|
||||
height: 512
|
||||
stable-diffusion-1.5:
|
||||
config: configs/stable-diffusion/v1-inference.yaml
|
||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||
|
@ -76,4 +76,4 @@ model:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
|
||||
|
79
configs/stable-diffusion/v1-inpainting-inference.yaml
Normal file
@ -0,0 +1,79 @@
|
||||
model:
|
||||
base_learning_rate: 7.5e-05
|
||||
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid # important
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
finetune_keys: null
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
personalization_config:
|
||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||
params:
|
||||
placeholder_strings: ["*"]
|
||||
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
||||
per_image_tokens: false
|
||||
num_vectors_per_token: 1
|
||||
progressive_words: False
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
|
@ -1,57 +1,74 @@
|
||||
FROM debian
|
||||
FROM ubuntu AS get_miniconda
|
||||
|
||||
ARG gsd
|
||||
ENV GITHUB_STABLE_DIFFUSION $gsd
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ARG rsd
|
||||
ENV REQS $rsd
|
||||
# install wget
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
wget \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ARG cs
|
||||
ENV CONDA_SUBDIR $cs
|
||||
# download and install miniconda
|
||||
ARG conda_version=py39_4.12.0-Linux-x86_64
|
||||
ARG conda_prefix=/opt/conda
|
||||
RUN wget --progress=dot:giga -O /miniconda.sh \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-${conda_version}.sh \
|
||||
&& bash /miniconda.sh -b -p ${conda_prefix} \
|
||||
&& rm -f /miniconda.sh
|
||||
|
||||
ENV PIP_EXISTS_ACTION="w"
|
||||
FROM ubuntu AS invokeai
|
||||
|
||||
# TODO: Optimize image size
|
||||
# use bash
|
||||
SHELL [ "/bin/bash", "-c" ]
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
# clean bashrc
|
||||
RUN echo "" > ~/.bashrc
|
||||
|
||||
WORKDIR /
|
||||
RUN apt update && apt upgrade -y \
|
||||
&& apt install -y \
|
||||
git \
|
||||
libgl1-mesa-glx \
|
||||
libglib2.0-0 \
|
||||
pip \
|
||||
python3 \
|
||||
&& git clone $GITHUB_STABLE_DIFFUSION
|
||||
# Install necesarry packages
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
--no-install-recommends \
|
||||
gcc \
|
||||
git \
|
||||
libgl1-mesa-glx \
|
||||
libglib2.0-0 \
|
||||
pip \
|
||||
python3 \
|
||||
python3-dev \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Anaconda or Miniconda
|
||||
COPY anaconda.sh .
|
||||
RUN bash anaconda.sh -b -u -p /anaconda && /anaconda/bin/conda init bash
|
||||
# clone repository and create symlinks
|
||||
ARG invokeai_git=https://github.com/invoke-ai/InvokeAI.git
|
||||
ARG project_name=invokeai
|
||||
RUN git clone ${invokeai_git} /${project_name} \
|
||||
&& mkdir /${project_name}/models/ldm/stable-diffusion-v1 \
|
||||
&& ln -s /data/models/sd-v1-4.ckpt /${project_name}/models/ldm/stable-diffusion-v1/model.ckpt \
|
||||
&& ln -s /data/outputs/ /${project_name}/outputs
|
||||
|
||||
# SD
|
||||
WORKDIR /stable-diffusion
|
||||
RUN source ~/.bashrc \
|
||||
&& conda create -y --name ldm && conda activate ldm \
|
||||
&& conda config --env --set subdir $CONDA_SUBDIR \
|
||||
&& pip3 install -r $REQS \
|
||||
&& pip3 install basicsr facexlib realesrgan \
|
||||
&& mkdir models/ldm/stable-diffusion-v1 \
|
||||
&& ln -s "/data/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt
|
||||
# set workdir
|
||||
WORKDIR /${project_name}
|
||||
|
||||
# Face restoreation
|
||||
# by default expected in a sibling directory to stable-diffusion
|
||||
WORKDIR /
|
||||
RUN git clone https://github.com/TencentARC/GFPGAN.git
|
||||
# install conda env and preload models
|
||||
ARG conda_prefix=/opt/conda
|
||||
ARG conda_env_file=environment.yml
|
||||
COPY --from=get_miniconda ${conda_prefix} ${conda_prefix}
|
||||
RUN source ${conda_prefix}/etc/profile.d/conda.sh \
|
||||
&& conda init bash \
|
||||
&& source ~/.bashrc \
|
||||
&& conda env create \
|
||||
--name ${project_name} \
|
||||
--file ${conda_env_file} \
|
||||
&& rm -Rf ~/.cache \
|
||||
&& conda clean -afy \
|
||||
&& echo "conda activate ${project_name}" >> ~/.bashrc \
|
||||
&& ln -s /data/models/GFPGANv1.4.pth ./src/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth \
|
||||
&& conda activate ${project_name} \
|
||||
&& python scripts/preload_models.py
|
||||
|
||||
WORKDIR /GFPGAN
|
||||
RUN pip3 install -r requirements.txt \
|
||||
&& python3 setup.py develop \
|
||||
&& ln -s "/data/GFPGANv1.4.pth" experiments/pretrained_models/GFPGANv1.4.pth
|
||||
|
||||
WORKDIR /stable-diffusion
|
||||
RUN python3 scripts/preload_models.py
|
||||
|
||||
WORKDIR /
|
||||
COPY entrypoint.sh .
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
# Copy entrypoint and set env
|
||||
ENV CONDA_PREFIX=${conda_prefix}
|
||||
ENV PROJECT_NAME=${project_name}
|
||||
COPY docker-build/entrypoint.sh /
|
||||
ENTRYPOINT [ "/entrypoint.sh" ]
|
||||
|
81
docker-build/build.sh
Executable file
@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
# IMPORTANT: You need to have a token on huggingface.co to be able to download the checkpoint!!!
|
||||
# configure values by using env when executing build.sh
|
||||
# f.e. env ARCH=aarch64 GITHUB_INVOKE_AI=https://github.com/yourname/yourfork.git ./build.sh
|
||||
|
||||
source ./docker-build/env.sh || echo "please run from repository root" || exit 1
|
||||
|
||||
invokeai_conda_version=${INVOKEAI_CONDA_VERSION:-py39_4.12.0-${platform/\//-}}
|
||||
invokeai_conda_prefix=${INVOKEAI_CONDA_PREFIX:-\/opt\/conda}
|
||||
invokeai_conda_env_file=${INVOKEAI_CONDA_ENV_FILE:-environment.yml}
|
||||
invokeai_git=${INVOKEAI_GIT:-https://github.com/invoke-ai/InvokeAI.git}
|
||||
huggingface_token=${HUGGINGFACE_TOKEN?}
|
||||
|
||||
# print the settings
|
||||
echo "You are using these values:"
|
||||
echo -e "project_name:\t\t ${project_name}"
|
||||
echo -e "volumename:\t\t ${volumename}"
|
||||
echo -e "arch:\t\t\t ${arch}"
|
||||
echo -e "platform:\t\t ${platform}"
|
||||
echo -e "invokeai_conda_version:\t ${invokeai_conda_version}"
|
||||
echo -e "invokeai_conda_prefix:\t ${invokeai_conda_prefix}"
|
||||
echo -e "invokeai_conda_env_file: ${invokeai_conda_env_file}"
|
||||
echo -e "invokeai_git:\t\t ${invokeai_git}"
|
||||
echo -e "invokeai_tag:\t\t ${invokeai_tag}\n"
|
||||
|
||||
_runAlpine() {
|
||||
docker run \
|
||||
--rm \
|
||||
--interactive \
|
||||
--tty \
|
||||
--mount source="$volumename",target=/data \
|
||||
--workdir /data \
|
||||
alpine "$@"
|
||||
}
|
||||
|
||||
_copyCheckpoints() {
|
||||
echo "creating subfolders for models and outputs"
|
||||
_runAlpine mkdir models
|
||||
_runAlpine mkdir outputs
|
||||
echo -n "downloading sd-v1-4.ckpt"
|
||||
_runAlpine wget --header="Authorization: Bearer ${huggingface_token}" -O models/sd-v1-4.ckpt https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||
echo "done"
|
||||
echo "downloading GFPGANv1.4.pth"
|
||||
_runAlpine wget -O models/GFPGANv1.4.pth https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth
|
||||
}
|
||||
|
||||
_checkVolumeContent() {
|
||||
_runAlpine ls -lhA /data/models
|
||||
}
|
||||
|
||||
_getModelMd5s() {
|
||||
_runAlpine \
|
||||
alpine sh -c "md5sum /data/models/*"
|
||||
}
|
||||
|
||||
if [[ -n "$(docker volume ls -f name="${volumename}" -q)" ]]; then
|
||||
echo "Volume already exists"
|
||||
if [[ -z "$(_checkVolumeContent)" ]]; then
|
||||
echo "looks empty, copying checkpoint"
|
||||
_copyCheckpoints
|
||||
fi
|
||||
echo "Models in ${volumename}:"
|
||||
_checkVolumeContent
|
||||
else
|
||||
echo -n "createing docker volume "
|
||||
docker volume create "${volumename}"
|
||||
_copyCheckpoints
|
||||
fi
|
||||
|
||||
# Build Container
|
||||
docker build \
|
||||
--platform="${platform}" \
|
||||
--tag "${invokeai_tag}" \
|
||||
--build-arg project_name="${project_name}" \
|
||||
--build-arg conda_version="${invokeai_conda_version}" \
|
||||
--build-arg conda_prefix="${invokeai_conda_prefix}" \
|
||||
--build-arg conda_env_file="${invokeai_conda_env_file}" \
|
||||
--build-arg invokeai_git="${invokeai_git}" \
|
||||
--file ./docker-build/Dockerfile \
|
||||
.
|
@ -1,10 +1,8 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
cd /stable-diffusion
|
||||
source "${CONDA_PREFIX}/etc/profile.d/conda.sh"
|
||||
conda activate "${PROJECT_NAME}"
|
||||
|
||||
if [ $# -eq 0 ]; then
|
||||
python3 scripts/dream.py --full_precision -o /data
|
||||
# bash
|
||||
else
|
||||
python3 scripts/dream.py --full_precision -o /data "$@"
|
||||
fi
|
||||
python scripts/invoke.py \
|
||||
${@:---web --host=0.0.0.0}
|
||||
|
13
docker-build/env.sh
Normal file
@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
project_name=${PROJECT_NAME:-invokeai}
|
||||
volumename=${VOLUMENAME:-${project_name}_data}
|
||||
arch=${ARCH:-x86_64}
|
||||
platform=${PLATFORM:-Linux/${arch}}
|
||||
invokeai_tag=${INVOKEAI_TAG:-${project_name}-${arch}}
|
||||
|
||||
export project_name
|
||||
export volumename
|
||||
export arch
|
||||
export platform
|
||||
export invokeai_tag
|
15
docker-build/run.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
source ./docker-build/env.sh || echo "please run from repository root" || exit 1
|
||||
|
||||
docker run \
|
||||
--interactive \
|
||||
--tty \
|
||||
--rm \
|
||||
--platform "$platform" \
|
||||
--name "$project_name" \
|
||||
--hostname "$project_name" \
|
||||
--mount source="$volumename",target=/data \
|
||||
--publish 9090:9090 \
|
||||
"$invokeai_tag" ${1:+$@}
|
BIN
docs/assets/prompt_syntax/apricots--1.png
Normal file
After Width: | Height: | Size: 587 KiB |
BIN
docs/assets/prompt_syntax/apricots--2.png
Normal file
After Width: | Height: | Size: 572 KiB |
BIN
docs/assets/prompt_syntax/apricots--3.png
Normal file
After Width: | Height: | Size: 557 KiB |
BIN
docs/assets/prompt_syntax/apricots-0.png
Normal file
After Width: | Height: | Size: 571 KiB |
BIN
docs/assets/prompt_syntax/apricots-1.png
Normal file
After Width: | Height: | Size: 570 KiB |
BIN
docs/assets/prompt_syntax/apricots-2.png
Normal file
After Width: | Height: | Size: 568 KiB |
BIN
docs/assets/prompt_syntax/apricots-3.png
Normal file
After Width: | Height: | Size: 527 KiB |
BIN
docs/assets/prompt_syntax/apricots-4.png
Normal file
After Width: | Height: | Size: 489 KiB |
BIN
docs/assets/prompt_syntax/apricots-5.png
Normal file
After Width: | Height: | Size: 503 KiB |
BIN
docs/assets/prompt_syntax/mountain-man.png
Normal file
After Width: | Height: | Size: 488 KiB |
BIN
docs/assets/prompt_syntax/mountain-man1.png
Normal file
After Width: | Height: | Size: 499 KiB |
BIN
docs/assets/prompt_syntax/mountain-man2.png
Normal file
After Width: | Height: | Size: 524 KiB |
BIN
docs/assets/prompt_syntax/mountain-man3.png
Normal file
After Width: | Height: | Size: 593 KiB |
BIN
docs/assets/prompt_syntax/mountain-man4.png
Normal file
After Width: | Height: | Size: 598 KiB |
BIN
docs/assets/prompt_syntax/mountain1-man.png
Normal file
After Width: | Height: | Size: 488 KiB |
BIN
docs/assets/prompt_syntax/mountain2-man.png
Normal file
After Width: | Height: | Size: 487 KiB |
BIN
docs/assets/prompt_syntax/mountain3-man.png
Normal file
After Width: | Height: | Size: 489 KiB |
@ -153,6 +153,7 @@ Here are the invoke> command that apply to txt2img:
|
||||
| --cfg_scale <float>| -C<float> | 7.5 | How hard to try to match the prompt to the generated image; any number greater than 1.0 works, but the useful range is roughly 5.0 to 20.0 |
|
||||
| --seed <int> | -S<int> | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.|
|
||||
| --sampler <sampler>| -A<sampler>| k_lms | Sampler to use. Use -h to get list of available samplers. |
|
||||
| --karras_max <int> | | 29 | When using k_* samplers, set the maximum number of steps before shifting from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts) This value is sticky. [29] |
|
||||
| --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution |
|
||||
| --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
||||
| --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt |
|
||||
@ -218,8 +219,13 @@ well as the --mask (-M) and --text_mask (-tm) arguments:
|
||||
| Argument <img width="100" align="right"/> | Shortcut | Default | Description |
|
||||
|--------------------|------------|---------------------|--------------|
|
||||
| `--init_mask <path>` | `-M<path>` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
|
||||
| `--invert_mask ` | | False |If true, invert the mask so that transparent areas are opaque and vice versa.|
|
||||
| `--text_mask <prompt> [<float>]` | `-tm <prompt> [<float>]` | <none> | Create a mask from a text prompt describing part of the image|
|
||||
|
||||
The mask may either be an image with transparent areas, in which case
|
||||
the inpainting will occur in the transparent areas only, or a black
|
||||
and white image, in which case all black areas will be painted into.
|
||||
|
||||
`--text_mask` (short form `-tm`) is a way to generate a mask using a
|
||||
text description of the part of the image to replace. For example, if
|
||||
you have an image of a breakfast plate with a bagel, toast and
|
||||
|
@ -121,8 +121,6 @@ Both of the outputs look kind of like what I was thinking of. With the strength
|
||||
|
||||
If you want to try this out yourself, all of these are using a seed of `1592514025` with a width/height of `384`, step count `10`, the default sampler (`k_lms`), and the single-word prompt `"fire"`:
|
||||
|
||||
If you want to try this out yourself, all of these are using a seed of `1592514025` with a width/height of `384`, step count `10`, the default sampler (`k_lms`), and the single-word prompt `fire`:
|
||||
|
||||
```commandline
|
||||
invoke> "fire" -s10 -W384 -H384 -S1592514025 -I /tmp/fire-drawing.png --strength 0.7
|
||||
```
|
||||
|
@ -34,6 +34,16 @@ original unedited image and the masked (partially transparent) image:
|
||||
invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png
|
||||
```
|
||||
|
||||
If you are using Photoshop to make your transparent masks, here is a
|
||||
protocol contributed by III_Communication36 (Discord name):
|
||||
|
||||
Create your alpha channel for mask in photoshop, then run
|
||||
image/adjust/threshold on that channel. Export as Save a copy using
|
||||
superpng (3rd party free download plugin) making sure alpha channel
|
||||
is selected. Then masking works as it should for the img2img
|
||||
process 100%. Can feed just one image this way without needing to
|
||||
feed the -M mask behind it
|
||||
|
||||
## **Masking using Text**
|
||||
|
||||
You can also create a mask using a text prompt to select the part of
|
||||
@ -139,7 +149,83 @@ region directly:
|
||||
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20
|
||||
```
|
||||
|
||||
### Inpainting is not changing the masked region enough!
|
||||
## Using the RunwayML inpainting model
|
||||
|
||||
The [RunwayML Inpainting Model
|
||||
v1.5](https://huggingface.co/runwayml/stable-diffusion-inpainting) is
|
||||
a specialized version of [Stable Diffusion
|
||||
v1.5](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5)
|
||||
that contains extra channels specifically designed to enhance
|
||||
inpainting and outpainting. While it can do regular `txt2img` and
|
||||
`img2img`, it really shines when filling in missing regions. It has an
|
||||
almost uncanny ability to blend the new regions with existing ones in
|
||||
a semantically coherent way.
|
||||
|
||||
To install the inpainting model, follow the
|
||||
[instructions](INSTALLING-MODELS.md) for installing a new model. You
|
||||
may use either the CLI (`invoke.py` script) or directly edit the
|
||||
`configs/models.yaml` configuration file to do this. The main thing to
|
||||
watch out for is that the the model `config` option must be set up to
|
||||
use `v1-inpainting-inference.yaml` rather than the `v1-inference.yaml`
|
||||
file that is used by Stable Diffusion 1.4 and 1.5.
|
||||
|
||||
After installation, your `models.yaml` should contain an entry that
|
||||
looks like this one:
|
||||
|
||||
inpainting-1.5:
|
||||
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||
description: SD inpainting v1.5
|
||||
config: configs/stable-diffusion/v1-inpainting-inference.yaml
|
||||
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
width: 512
|
||||
height: 512
|
||||
|
||||
As shown in the example, you may include a VAE fine-tuning weights
|
||||
file as well. This is strongly recommended.
|
||||
|
||||
To use the custom inpainting model, launch `invoke.py` with the
|
||||
argument `--model inpainting-1.5` or alternatively from within the
|
||||
script use the `!switch inpainting-1.5` command to load and switch to
|
||||
the inpainting model.
|
||||
|
||||
You can now do inpainting and outpainting exactly as described above,
|
||||
but there will (likely) be a noticeable improvement in
|
||||
coherence. Txt2img and Img2img will work as well.
|
||||
|
||||
There are a few caveats to be aware of:
|
||||
|
||||
1. The inpainting model is larger than the standard model, and will
|
||||
use nearly 4 GB of GPU VRAM. This makes it unlikely to run on
|
||||
a 4 GB graphics card.
|
||||
|
||||
2. When operating in Img2img mode, the inpainting model is much less
|
||||
steerable than the standard model. It is great for making small
|
||||
changes, such as changing the pattern of a fabric, or slightly
|
||||
changing a subject's expression or hair, but the model will
|
||||
resist making the dramatic alterations that the standard
|
||||
model lets you do.
|
||||
|
||||
3. While the `--hires` option works fine with the inpainting model,
|
||||
some special features, such as `--embiggen` are disabled.
|
||||
|
||||
4. Prompt weighting (`banana++ sushi`) and merging work well with
|
||||
the inpainting model, but prompt swapping (a ("fluffy cat").swap("smiling dog") eating a hotdog`)
|
||||
will not have any effect due to the way the model is set up.
|
||||
You may use text masking (with `-tm thing-to-mask`) as an
|
||||
effective replacement.
|
||||
|
||||
5. The model tends to oversharpen image if you use high step or CFG
|
||||
values. If you need to do large steps, use the standard model.
|
||||
|
||||
6. The `--strength` (`-f`) option has no effect on the inpainting
|
||||
model due to its fundamental differences with the standard
|
||||
model. It will always take the full number of steps you specify.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
Here are some troubleshooting tips for inpainting and outpainting.
|
||||
|
||||
## Inpainting is not changing the masked region enough!
|
||||
|
||||
One of the things to understand about how inpainting works is that it
|
||||
is equivalent to running img2img on just the masked (transparent)
|
||||
|
@ -15,13 +15,52 @@ InvokeAI supports two versions of outpainting, one called "outpaint"
|
||||
and the other "outcrop." They work slightly differently and each has
|
||||
its advantages and drawbacks.
|
||||
|
||||
### Outpainting
|
||||
|
||||
Outpainting is the same as inpainting, except that the painting occurs
|
||||
in the regions outside of the original image. To outpaint using the
|
||||
`invoke.py` command line script, prepare an image in which the borders
|
||||
to be extended are pure black. Add an alpha channel (if there isn't one
|
||||
already), and make the borders completely transparent and the interior
|
||||
completely opaque. If you wish to modify the interior as well, you may
|
||||
create transparent holes in the transparency layer, which `img2img` will
|
||||
paint into as usual.
|
||||
|
||||
Pass the image as the argument to the `-I` switch as you would for
|
||||
regular inpainting:
|
||||
|
||||
invoke> a stream by a river -I /path/to/transparent_img.png
|
||||
|
||||
You'll likely be delighted by the results.
|
||||
|
||||
### Tips
|
||||
|
||||
1. Do not try to expand the image too much at once. Generally it is best
|
||||
to expand the margins in 64-pixel increments. 128 pixels often works,
|
||||
but your mileage may vary depending on the nature of the image you are
|
||||
trying to outpaint into.
|
||||
|
||||
2. There are a series of switches that can be used to adjust how the
|
||||
inpainting algorithm operates. In particular, you can use these to
|
||||
minimize the seam that sometimes appears between the original image
|
||||
and the extended part. These switches are:
|
||||
|
||||
--seam_size SEAM_SIZE Size of the mask around the seam between original and outpainted image (0)
|
||||
--seam_blur SEAM_BLUR The amount to blur the seam inwards (0)
|
||||
--seam_strength STRENGTH The img2img strength to use when filling the seam (0.7)
|
||||
--seam_steps SEAM_STEPS The number of steps to use to fill the seam. (10)
|
||||
--tile_size TILE_SIZE The tile size to use for filling outpaint areas (32)
|
||||
|
||||
### Outcrop
|
||||
|
||||
The `outcrop` extension allows you to extend the image in 64 pixel
|
||||
increments in any dimension. You can apply the module to any image
|
||||
previously-generated by InvokeAI. Note that it will **not** work with
|
||||
arbitrary photographs or Stable Diffusion images created by other
|
||||
implementations.
|
||||
The `outcrop` extension gives you a convenient `!fix` postprocessing
|
||||
command that allows you to extend a previously-generated image in 64
|
||||
pixel increments in any direction. You can apply the module to any
|
||||
image previously-generated by InvokeAI. Note that it works with
|
||||
arbitrary PNG photographs, but not currently with JPG or other
|
||||
formats. Outcropping is particularly effective when combined with the
|
||||
[runwayML custom inpainting
|
||||
model](INPAINTING.md#using-the-runwayml-inpainting-model).
|
||||
|
||||
Consider this image:
|
||||
|
||||
@ -64,42 +103,3 @@ you'll get a slightly different result. You can run it repeatedly
|
||||
until you get an image you like. Unfortunately `!fix` does not
|
||||
currently respect the `-n` (`--iterations`) argument.
|
||||
|
||||
## Outpaint
|
||||
|
||||
The `outpaint` extension does the same thing, but with subtle
|
||||
differences. Starting with the same image, here is how we would add an
|
||||
additional 64 pixels to the top of the image:
|
||||
|
||||
```bash
|
||||
invoke> !fix images/curly.png --out_direction top 64
|
||||
```
|
||||
|
||||
(you can abbreviate `--out_direction` as `-D`.
|
||||
|
||||
The result is shown here:
|
||||
|
||||
<div align="center" markdown>
|
||||
![curly_woman_outpaint](../assets/outpainting/curly-outpaint.png)
|
||||
</div>
|
||||
|
||||
Although the effect is similar, there are significant differences from
|
||||
outcropping:
|
||||
|
||||
- You can only specify one direction to extend at a time.
|
||||
- The image is **not** resized. Instead, the image is shifted by the specified
|
||||
number of pixels. If you look carefully, you'll see that less of the lady's
|
||||
torso is visible in the image.
|
||||
- Because the image dimensions remain the same, there's no rounding
|
||||
to multiples of 64.
|
||||
- Attempting to outpaint larger areas will frequently give rise to ugly
|
||||
ghosting effects.
|
||||
- For best results, try increasing the step number.
|
||||
- If you don't specify a pixel value in `-D`, it will default to half
|
||||
of the whole image, which is likely not what you want.
|
||||
|
||||
!!! tip
|
||||
|
||||
Neither `outpaint` nor `outcrop` are perfect, but we continue to tune
|
||||
and improve them. If one doesn't work, try the other. You may also
|
||||
wish to experiment with other `img2img` arguments, such as `-C`, `-f`
|
||||
and `-s`.
|
||||
|
@ -45,7 +45,7 @@ Here's a prompt that depicts what it does.
|
||||
|
||||
original prompt:
|
||||
|
||||
`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
|
||||
`#!bash "A fantastical translucent pony made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
|
||||
|
||||
<div align="center" markdown>
|
||||
![step1](../assets/negative_prompt_walkthru/step1.png)
|
||||
@ -84,6 +84,109 @@ Getting close - but there's no sense in having a saddle when our horse doesn't h
|
||||
|
||||
---
|
||||
|
||||
## **Prompt Syntax Features**
|
||||
|
||||
The InvokeAI prompting language has the following features:
|
||||
|
||||
### Attention weighting
|
||||
Append a word or phrase with `-` or `+`, or a weight between `0` and `2` (`1`=default), to decrease or increase "attention" (= a mix of per-token CFG weighting multiplier and, for `-`, a weighted blend with the prompt without the term).
|
||||
|
||||
The following syntax is recognised:
|
||||
* single words without parentheses: `a tall thin man picking apricots+`
|
||||
* single or multiple words with parentheses: `a tall thin man picking (apricots)+` `a tall thin man picking (apricots)-` `a tall thin man (picking apricots)+` `a tall thin man (picking apricots)-`
|
||||
* more effect with more symbols `a tall thin man (picking apricots)++`
|
||||
* nesting `a tall thin man (picking apricots+)++` (`apricots` effectively gets `+++`)
|
||||
* all of the above with explicit numbers `a tall thin man picking (apricots)1.1` `a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++` is pow(1.1,2), `+++` is pow(1.1,3), etc; `-` means 0.9, `--` means pow(0.9,2), etc.)
|
||||
* attention also applies to `[unconditioning]` so `a tall thin man picking apricots [(ladder)0.01]` will *very gently* nudge SD away from trying to draw the man on a ladder
|
||||
|
||||
You can use this to increase or decrease the amount of something. Starting from this prompt of `a man picking apricots from a tree`, let's see what happens if we increase and decrease how much attention we want Stable Diffusion to pay to the word `apricots`:
|
||||
|
||||
![an AI generated image of a man picking apricots from a tree](../assets/prompt_syntax/apricots-0.png)
|
||||
|
||||
Using `-` to reduce apricot-ness:
|
||||
|
||||
| `a man picking apricots- from a tree` | `a man picking apricots-- from a tree` | `a man picking apricots--- from a tree` |
|
||||
| -- | -- | -- |
|
||||
| ![an AI generated image of a man picking apricots from a tree, with smaller apricots](../assets/prompt_syntax/apricots--1.png) | ![an AI generated image of a man picking apricots from a tree, with even smaller and fewer apricots](../assets/prompt_syntax/apricots--2.png) | ![an AI generated image of a man picking apricots from a tree, with very few very small apricots](../assets/prompt_syntax/apricots--3.png) |
|
||||
|
||||
Using `+` to increase apricot-ness:
|
||||
|
||||
| `a man picking apricots+ from a tree` | `a man picking apricots++ from a tree` | `a man picking apricots+++ from a tree` | `a man picking apricots++++ from a tree` | `a man picking apricots+++++ from a tree` |
|
||||
| -- | -- | -- | -- | -- |
|
||||
| ![an AI generated image of a man picking apricots from a tree, with larger, more vibrant apricots](../assets/prompt_syntax/apricots-1.png) | ![an AI generated image of a man picking apricots from a tree with even larger, even more vibrant apricots](../assets/prompt_syntax/apricots-2.png) | ![an AI generated image of a man picking apricots from a tree, but the man has been replaced by a pile of apricots](../assets/prompt_syntax/apricots-3.png) | ![an AI generated image of a man picking apricots from a tree, but the man has been replaced by a mound of giant melting-looking apricots](../assets/prompt_syntax/apricots-4.png) | ![an AI generated image of a man picking apricots from a tree, but the man and the leaves and parts of the ground have all been replaced by giant melting-looking apricots](../assets/prompt_syntax/apricots-5.png) |
|
||||
|
||||
You can also change the balance between different parts of a prompt. For example, below is a `mountain man`:
|
||||
|
||||
![an AI generated image of a mountain man](../assets/prompt_syntax/mountain-man.png)
|
||||
|
||||
And here he is with more mountain:
|
||||
|
||||
| `mountain+ man` | `mountain++ man` | `mountain+++ man` |
|
||||
| -- | -- | -- |
|
||||
| ![](../assets/prompt_syntax/mountain1-man.png) | ![](../assets/prompt_syntax/mountain2-man.png) | ![](../assets/prompt_syntax/mountain3-man.png) |
|
||||
|
||||
Or, alternatively, with more man:
|
||||
|
||||
| `mountain man+` | `mountain man++` | `mountain man+++` | `mountain man++++` |
|
||||
| -- | -- | -- | -- |
|
||||
| ![](../assets/prompt_syntax/mountain-man1.png) | ![](../assets/prompt_syntax/mountain-man2.png) | ![](../assets/prompt_syntax/mountain-man3.png) | ![](../assets/prompt_syntax/mountain-man4.png) |
|
||||
|
||||
### Blending between prompts
|
||||
|
||||
* `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)`
|
||||
* The existing prompt blending using `:<weight>` will continue to be supported - `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` is equivalent to `a tall thin man picking apricots:1 a tall thin man picking pears:1` in the old syntax.
|
||||
* Attention weights can be nested inside blends.
|
||||
* Non-normalized blends are supported by passing `no_normalize` as an additional argument to the blend weights, eg `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,-1,no_normalize)`. very fun to explore local maxima in the feature space, but also easy to produce garbage output.
|
||||
|
||||
See the section below on "Prompt Blending" for more information about how this works.
|
||||
|
||||
### Cross-Attention Control ('prompt2prompt')
|
||||
|
||||
Sometimes an image you generate is almost right, and you just want to
|
||||
change one detail without affecting the rest. You could use a photo editor and inpainting
|
||||
to overpaint the area, but that's a pain. Here's where `prompt2prompt`
|
||||
comes in handy.
|
||||
|
||||
Generate an image with a given prompt, record the seed of the image,
|
||||
and then use the `prompt2prompt` syntax to substitute words in the
|
||||
original prompt for words in a new prompt. This works for `img2img` as well.
|
||||
|
||||
* `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
|
||||
* quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
|
||||
* for single word substitutions parentheses are also optional: `a cat.swap(dog) eating a hotdog`.
|
||||
* Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_spatial_start/_end` and `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to intuitively understand.
|
||||
* Example usage:`a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` argument means that the "spatial" (self-attention) edit will stop having any effect after 30% (=0.3) of the steps have been done, leaving Stable Diffusion with 70% of the steps where it is free to decide for itself how to reshape the cat-form into a dog form.
|
||||
* The numbers represent a percentage through the step sequence where the edits should happen. 0 means the start (noisy starting image), 1 is the end (final image).
|
||||
* For img2img, the step sequence does not start at 0 but instead at (1-strength) - so if strength is 0.7, s_start and s_end must both be greater than 0.3 (1-0.7) to have any effect.
|
||||
* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
|
||||
* `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
|
||||
|
||||
|
||||
|
||||
The `prompt2prompt` code is based off [bloc97's
|
||||
colab](https://github.com/bloc97/CrossAttentionControl).
|
||||
|
||||
Note that `prompt2prompt` is not currently working with the runwayML
|
||||
inpainting model, and may never work due to the way this model is set
|
||||
up. If you attempt to use `prompt2prompt` you will get the original
|
||||
image back. However, since this model is so good at inpainting, a
|
||||
good substitute is to use the `clipseg` text masking option:
|
||||
|
||||
```
|
||||
invoke> a fluffy cat eating a hotdot
|
||||
Outputs:
|
||||
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
|
||||
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
|
||||
```
|
||||
|
||||
### Escaping parantheses () and speech marks ""
|
||||
|
||||
If the model you are using has parentheses () or speech marks "" as
|
||||
part of its syntax, you will need to "escape" these using a backslash,
|
||||
so that`(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt
|
||||
parser will attempt to interpret the parentheses as part of the prompt
|
||||
syntax and it will get confused.
|
||||
|
||||
## **Prompt Blending**
|
||||
|
||||
You may blend together different sections of the prompt to explore the
|
||||
|
@ -36,20 +36,6 @@ another environment with NVIDIA GPUs on-premises or in the cloud.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
#### Get the data files
|
||||
|
||||
Go to
|
||||
[Hugging Face](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original),
|
||||
and click "Access repository" to Download the model file `sd-v1-4.ckpt` (~4 GB)
|
||||
to `~/Downloads`. You'll need to create an account but it's quick and free.
|
||||
|
||||
Also download the face restoration model.
|
||||
|
||||
```Shell
|
||||
cd ~/Downloads
|
||||
wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth
|
||||
```
|
||||
|
||||
#### Install [Docker](https://github.com/santisbon/guides#docker)
|
||||
|
||||
On the Docker Desktop app, go to Preferences, Resources, Advanced. Increase the
|
||||
@ -57,86 +43,61 @@ CPUs and Memory to avoid this
|
||||
[Issue](https://github.com/invoke-ai/InvokeAI/issues/342). You may need to
|
||||
increase Swap and Disk image size too.
|
||||
|
||||
#### Get a Huggingface-Token
|
||||
|
||||
Go to [Hugging Face](https://huggingface.co/settings/tokens), create a token and
|
||||
temporary place it somewhere like a open texteditor window (but dont save it!,
|
||||
only keep it open, we need it in the next step)
|
||||
|
||||
### Setup
|
||||
|
||||
Set the fork you want to use and other variables.
|
||||
|
||||
```Shell
|
||||
TAG_STABLE_DIFFUSION="santisbon/stable-diffusion"
|
||||
PLATFORM="linux/arm64"
|
||||
GITHUB_STABLE_DIFFUSION="-b orig-gfpgan https://github.com/santisbon/stable-diffusion.git"
|
||||
REQS_STABLE_DIFFUSION="requirements-linux-arm64.txt"
|
||||
CONDA_SUBDIR="osx-arm64"
|
||||
!!! tip
|
||||
|
||||
echo $TAG_STABLE_DIFFUSION
|
||||
echo $PLATFORM
|
||||
echo $GITHUB_STABLE_DIFFUSION
|
||||
echo $REQS_STABLE_DIFFUSION
|
||||
echo $CONDA_SUBDIR
|
||||
I preffer to save my env vars
|
||||
in the repository root in a `.env` (or `.envrc`) file to automatically re-apply
|
||||
them when I come back.
|
||||
|
||||
The build- and run- scripts contain default values for almost everything,
|
||||
besides the [Hugging Face Token](https://huggingface.co/settings/tokens) you
|
||||
created in the last step.
|
||||
|
||||
Some Suggestions of variables you may want to change besides the Token:
|
||||
|
||||
| Environment-Variable | Description |
|
||||
| ------------------------------------------------------------------- | ------------------------------------------------------------------------ |
|
||||
| `HUGGINGFACE_TOKEN="hg_aewirhghlawrgkjbarug2"` | This is the only required variable, without you can't get the checkpoint |
|
||||
| `ARCH=aarch64` | if you are using a ARM based CPU |
|
||||
| `INVOKEAI_TAG=yourname/invokeai:latest` | the Container Repository / Tag which will be used |
|
||||
| `INVOKEAI_CONDA_ENV_FILE=environment-linux-aarch64.yml` | since environment.yml wouldn't work with aarch |
|
||||
| `INVOKEAI_GIT="-b branchname https://github.com/username/reponame"` | if you want to use your own fork |
|
||||
|
||||
#### Build the Image
|
||||
|
||||
I provided a build script, which is located in `docker-build/build.sh` but still
|
||||
needs to be executed from the Repository root.
|
||||
|
||||
```bash
|
||||
docker-build/build.sh
|
||||
```
|
||||
|
||||
Create a Docker volume for the downloaded model files.
|
||||
The build Script not only builds the container, but also creates the docker
|
||||
volume if not existing yet, or if empty it will just download the models. When
|
||||
it is done you can run the container via the run script
|
||||
|
||||
```Shell
|
||||
docker volume create my-vol
|
||||
```bash
|
||||
docker-build/run.sh
|
||||
```
|
||||
|
||||
Copy the data files to the Docker volume using a lightweight Linux container.
|
||||
We'll need the models at run time. You just need to create the container with
|
||||
the mountpoint; no need to run this dummy container.
|
||||
When used without arguments, the container will start the website and provide
|
||||
you the link to open it. But if you want to use some other parameters you can
|
||||
also do so.
|
||||
|
||||
```Shell
|
||||
cd ~/Downloads # or wherever you saved the files
|
||||
!!! warning "Deprecated"
|
||||
|
||||
docker create --platform $PLATFORM --name dummy --mount source=my-vol,target=/data alpine
|
||||
|
||||
docker cp sd-v1-4.ckpt dummy:/data
|
||||
docker cp GFPGANv1.4.pth dummy:/data
|
||||
```
|
||||
|
||||
Get the repo and download the Miniconda installer (we'll need it at build time).
|
||||
Replace the URL with the version matching your container OS and the architecture
|
||||
it will run on.
|
||||
|
||||
```Shell
|
||||
cd ~
|
||||
git clone $GITHUB_STABLE_DIFFUSION
|
||||
|
||||
cd stable-diffusion/docker-build
|
||||
chmod +x entrypoint.sh
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh -O anaconda.sh && chmod +x anaconda.sh
|
||||
```
|
||||
|
||||
Build the Docker image. Give it any tag `-t` that you want.
|
||||
Choose the Linux container's host platform: x86-64/Intel is `amd64`. Apple
|
||||
silicon is `arm64`. If deploying the container to the cloud to leverage powerful
|
||||
GPU instances you'll be on amd64 hardware but if you're just trying this out
|
||||
locally on Apple silicon choose arm64.
|
||||
The application uses libraries that need to match the host environment so use
|
||||
the appropriate requirements file.
|
||||
Tip: Check that your shell session has the env variables set above.
|
||||
|
||||
```Shell
|
||||
docker build -t $TAG_STABLE_DIFFUSION \
|
||||
--platform $PLATFORM \
|
||||
--build-arg gsd=$GITHUB_STABLE_DIFFUSION \
|
||||
--build-arg rsd=$REQS_STABLE_DIFFUSION \
|
||||
--build-arg cs=$CONDA_SUBDIR \
|
||||
.
|
||||
```
|
||||
|
||||
Run a container using your built image.
|
||||
Tip: Make sure you've created and populated the Docker volume (above).
|
||||
|
||||
```Shell
|
||||
docker run -it \
|
||||
--rm \
|
||||
--platform $PLATFORM \
|
||||
--name stable-diffusion \
|
||||
--hostname stable-diffusion \
|
||||
--mount source=my-vol,target=/data \
|
||||
$TAG_STABLE_DIFFUSION
|
||||
```
|
||||
From here on it is the rest of the previous Docker-Docs, which will still
|
||||
provide usefull informations for one or the other.
|
||||
|
||||
## Usage (time to have fun)
|
||||
|
||||
@ -240,7 +201,8 @@ server with:
|
||||
python3 scripts/invoke.py --full_precision --web
|
||||
```
|
||||
|
||||
If it's running on your Mac point your Mac web browser to http://127.0.0.1:9090
|
||||
If it's running on your Mac point your Mac web browser to
|
||||
<http://127.0.0.1:9090>
|
||||
|
||||
Press Control-C at the command line to stop the web server.
|
||||
|
||||
|
44
environment-linux-aarch64.yml
Normal file
@ -0,0 +1,44 @@
|
||||
name: invokeai
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- python>=3.9
|
||||
- pip>=20.3
|
||||
- cudatoolkit
|
||||
- pytorch
|
||||
- torchvision
|
||||
- numpy=1.19
|
||||
- imageio=2.9.0
|
||||
- opencv=4.6.0
|
||||
- pillow=8.*
|
||||
- flask=2.1.*
|
||||
- flask_cors=3.0.10
|
||||
- flask-socketio=5.3.0
|
||||
- send2trash=1.8.0
|
||||
- eventlet
|
||||
- albumentations=0.4.3
|
||||
- pudb=2019.2
|
||||
- imageio-ffmpeg=0.4.2
|
||||
- pytorch-lightning=1.7.7
|
||||
- streamlit
|
||||
- einops=0.3.0
|
||||
- kornia=0.6
|
||||
- torchmetrics=0.7.0
|
||||
- transformers=4.21.3
|
||||
- torch-fidelity=0.3.0
|
||||
- tokenizers>=0.11.1,!=0.11.3,<0.13
|
||||
- pip:
|
||||
- omegaconf==2.1.1
|
||||
- realesrgan==0.2.5.0
|
||||
- test-tube>=0.7.5
|
||||
- pyreadline3
|
||||
- dependency_injector==4.40.0
|
||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||
- -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
- -e .
|
||||
variables:
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 1
|
126
ldm/generate.py
@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
|
||||
import pyparsing
|
||||
# Derived from source code carrying the following copyrights
|
||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
@ -24,6 +24,7 @@ from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
from pytorch_lightning import seed_everything, logging
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter
|
||||
from ldm.invoke.args import metadata_from_png
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.conditioning import get_uc_and_c
|
||||
from ldm.invoke.conditioning import get_uc_and_c_and_ec
|
||||
from ldm.invoke.model_cache import ModelCache
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
@ -179,6 +180,7 @@ class Generate:
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
self.safety_checker = None
|
||||
self.karras_max = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -269,10 +271,12 @@ class Generate:
|
||||
variation_amount = 0.0,
|
||||
threshold = 0.0,
|
||||
perlin = 0.0,
|
||||
karras_max = None,
|
||||
# these are specific to img2img and inpaint
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
text_mask = None,
|
||||
invert_mask = False,
|
||||
fit = False,
|
||||
strength = None,
|
||||
init_color = None,
|
||||
@ -293,6 +297,13 @@ class Generate:
|
||||
catch_interrupts = False,
|
||||
hires_fix = False,
|
||||
use_mps_noise = False,
|
||||
# Seam settings for outpainting
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
tile_size: int = 32,
|
||||
force_outpaint: bool = False,
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
"""
|
||||
@ -310,6 +321,7 @@ class Generate:
|
||||
init_img // path to an initial image
|
||||
init_mask // path to a mask for the initial image
|
||||
text_mask // a text string that will be used to guide clipseg generation of the init_mask
|
||||
invert_mask // boolean, if true invert the mask
|
||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
@ -350,7 +362,8 @@ class Generate:
|
||||
strength = strength or self.strength
|
||||
self.seed = seed
|
||||
self.log_tokenization = log_tokenization
|
||||
self.step_callback = step_callback
|
||||
self.step_callback = step_callback
|
||||
self.karras_max = karras_max
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
# will instantiate the model or return it from cache
|
||||
@ -395,6 +408,11 @@ class Generate:
|
||||
self.sampler_name = sampler_name
|
||||
self._set_sampler()
|
||||
|
||||
# bit of a hack to change the cached sampler's karras threshold to
|
||||
# whatever the user asked for
|
||||
if karras_max is not None and isinstance(self.sampler,KSampler):
|
||||
self.sampler.adjust_settings(karras_max=karras_max)
|
||||
|
||||
tic = time.time()
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
@ -404,7 +422,7 @@ class Generate:
|
||||
mask_image = None
|
||||
|
||||
try:
|
||||
uc, c = get_uc_and_c(
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=skip_normalize,
|
||||
log_tokens =self.log_tokenization
|
||||
@ -417,19 +435,12 @@ class Generate:
|
||||
height,
|
||||
fit=fit,
|
||||
text_mask=text_mask,
|
||||
invert_mask=invert_mask,
|
||||
force_outpaint=force_outpaint,
|
||||
)
|
||||
|
||||
# TODO: Hacky selection of operation to perform. Needs to be refactored.
|
||||
if (init_image is not None) and (mask_image is not None):
|
||||
generator = self._make_inpaint()
|
||||
elif (embiggen != None or embiggen_tiles != None):
|
||||
generator = self._make_embiggen()
|
||||
elif init_image is not None:
|
||||
generator = self._make_img2img()
|
||||
elif hires_fix:
|
||||
generator = self._make_txt2img2img()
|
||||
else:
|
||||
generator = self._make_txt2img()
|
||||
generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)
|
||||
|
||||
generator.set_variation(
|
||||
self.seed, variation_amount, with_variations
|
||||
@ -448,7 +459,7 @@ class Generate:
|
||||
sampler=self.sampler,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
conditioning=(uc, c),
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=image_callback, # called after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
@ -464,7 +475,13 @@ class Generate:
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
inpaint_replace=inpaint_replace,
|
||||
mask_blur_radius=mask_blur_radius,
|
||||
safety_checker=checker
|
||||
safety_checker=checker,
|
||||
seam_size = seam_size,
|
||||
seam_blur = seam_blur,
|
||||
seam_strength = seam_strength,
|
||||
seam_steps = seam_steps,
|
||||
tile_size = tile_size,
|
||||
force_outpaint = force_outpaint
|
||||
)
|
||||
|
||||
if init_color:
|
||||
@ -481,14 +498,14 @@ class Generate:
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
except KeyboardInterrupt:
|
||||
if catch_interrupts:
|
||||
print('**Interrupted** Partial results will be returned.')
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
|
||||
toc = time.time()
|
||||
print('>> Usage stats:')
|
||||
@ -545,7 +562,7 @@ class Generate:
|
||||
# try to reuse the same filename prefix as the original file.
|
||||
# we take everything up to the first period
|
||||
prefix = None
|
||||
m = re.match('^([^.]+)\.',os.path.basename(image_path))
|
||||
m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
|
||||
if m:
|
||||
prefix = m.groups()[0]
|
||||
|
||||
@ -553,7 +570,8 @@ class Generate:
|
||||
image = Image.open(image_path)
|
||||
|
||||
# used by multiple postfixers
|
||||
uc, c = get_uc_and_c(
|
||||
# todo: cross-attention control
|
||||
uc, c, _ = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=opt.skip_normalize,
|
||||
log_tokens =opt.log_tokenization
|
||||
@ -598,10 +616,9 @@ class Generate:
|
||||
|
||||
elif tool == 'embiggen':
|
||||
# fetch the metadata from the image
|
||||
generator = self._make_embiggen()
|
||||
generator = self.select_generator(embiggen=True)
|
||||
opt.strength = 0.40
|
||||
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
|
||||
# embiggen takes a image path (sigh)
|
||||
generator.generate(
|
||||
prompt,
|
||||
sampler = self.sampler,
|
||||
@ -635,6 +652,32 @@ class Generate:
|
||||
print(f'* postprocessing tool {tool} is not yet supported')
|
||||
return None
|
||||
|
||||
def select_generator(
|
||||
self,
|
||||
init_image:Image.Image=None,
|
||||
mask_image:Image.Image=None,
|
||||
embiggen:bool=False,
|
||||
hires_fix:bool=False,
|
||||
force_outpaint:bool=False,
|
||||
):
|
||||
inpainting_model_in_use = self.sampler.uses_inpainting_model()
|
||||
|
||||
if hires_fix:
|
||||
return self._make_txt2img2img()
|
||||
|
||||
if embiggen is not None:
|
||||
return self._make_embiggen()
|
||||
|
||||
if inpainting_model_in_use:
|
||||
return self._make_omnibus()
|
||||
|
||||
if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
|
||||
return self._make_inpaint()
|
||||
|
||||
if init_image is not None:
|
||||
return self._make_img2img()
|
||||
|
||||
return self._make_txt2img()
|
||||
|
||||
def _make_images(
|
||||
self,
|
||||
@ -644,6 +687,8 @@ class Generate:
|
||||
height,
|
||||
fit=False,
|
||||
text_mask=None,
|
||||
invert_mask=False,
|
||||
force_outpaint=False,
|
||||
):
|
||||
init_image = None
|
||||
init_mask = None
|
||||
@ -657,7 +702,7 @@ class Generate:
|
||||
|
||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||
if self._has_transparency(image):
|
||||
self._transparency_check_and_warning(image, mask)
|
||||
self._transparency_check_and_warning(image, mask, force_outpaint)
|
||||
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||
|
||||
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
||||
@ -673,8 +718,12 @@ class Generate:
|
||||
elif text_mask:
|
||||
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||
|
||||
if invert_mask:
|
||||
init_mask = ImageOps.invert(init_mask)
|
||||
|
||||
return init_image,init_mask
|
||||
|
||||
# lots o' repeated code here! Turn into a make_func()
|
||||
def _make_base(self):
|
||||
if not self.generators.get('base'):
|
||||
from ldm.invoke.generator import Generator
|
||||
@ -685,6 +734,7 @@ class Generate:
|
||||
if not self.generators.get('img2img'):
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
self.generators['img2img'] = Img2Img(self.model, self.precision)
|
||||
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['img2img']
|
||||
|
||||
def _make_embiggen(self):
|
||||
@ -713,6 +763,15 @@ class Generate:
|
||||
self.generators['inpaint'] = Inpaint(self.model, self.precision)
|
||||
return self.generators['inpaint']
|
||||
|
||||
# "omnibus" supports the runwayML custom inpainting model, which does
|
||||
# txt2img, img2img and inpainting using slight variations on the same code
|
||||
def _make_omnibus(self):
|
||||
if not self.generators.get('omnibus'):
|
||||
from ldm.invoke.generator.omnibus import Omnibus
|
||||
self.generators['omnibus'] = Omnibus(self.model, self.precision)
|
||||
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['omnibus']
|
||||
|
||||
def load_model(self):
|
||||
'''
|
||||
preload model identified in self.model_name
|
||||
@ -839,6 +898,8 @@ class Generate:
|
||||
def sample_to_image(self, samples):
|
||||
return self._make_base().sample_to_image(samples)
|
||||
|
||||
# very repetitive code - can this be simplified? The KSampler names are
|
||||
# consistent, at least
|
||||
def _set_sampler(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
if self.sampler_name == 'plms':
|
||||
@ -846,15 +907,11 @@ class Generate:
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'dpm_2_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'euler_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||
elif self.sampler_name == 'k_heun':
|
||||
@ -888,8 +945,9 @@ class Generate:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
return image
|
||||
|
||||
def _create_init_image(self, image, width, height, fit=True):
|
||||
image = image.convert('RGB')
|
||||
def _create_init_image(self, image: Image.Image, width, height, fit=True):
|
||||
if image.mode != 'RGBA':
|
||||
image = image.convert('RGBA')
|
||||
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
|
||||
return image
|
||||
|
||||
@ -954,11 +1012,11 @@ class Generate:
|
||||
colored += 1
|
||||
return colored == 0
|
||||
|
||||
def _transparency_check_and_warning(self,image, mask):
|
||||
def _transparency_check_and_warning(self,image, mask, force_outpaint=False):
|
||||
if not mask:
|
||||
print(
|
||||
'>> Initial image has transparent areas. Will inpaint in these regions.')
|
||||
if self._check_for_erasure(image):
|
||||
if (not force_outpaint) and self._check_for_erasure(image):
|
||||
print(
|
||||
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
|
||||
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
|
||||
|
@ -83,16 +83,16 @@ with metadata_from_png():
|
||||
import argparse
|
||||
from argparse import Namespace, RawTextHelpFormatter
|
||||
import pydoc
|
||||
import shlex
|
||||
import json
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import copy
|
||||
import base64
|
||||
import functools
|
||||
import ldm.invoke.pngwriter
|
||||
from ldm.invoke.conditioning import split_weighted_subprompts
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
'ddim',
|
||||
@ -169,28 +169,31 @@ class Args(object):
|
||||
|
||||
def parse_cmd(self,cmd_string):
|
||||
'''Parse a invoke>-style command string '''
|
||||
command = cmd_string.replace("'", "\\'")
|
||||
try:
|
||||
elements = shlex.split(command)
|
||||
elements = [x.replace("\\'","'") for x in elements]
|
||||
except ValueError:
|
||||
import sys, traceback
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return
|
||||
switches = ['']
|
||||
switches_started = False
|
||||
|
||||
for element in elements:
|
||||
if element[0] == '-' and not switches_started:
|
||||
switches_started = True
|
||||
if switches_started:
|
||||
switches.append(element)
|
||||
# handle the case in which the first token is a switch
|
||||
if cmd_string.startswith('-'):
|
||||
prompt = ''
|
||||
switches = cmd_string
|
||||
# handle the case in which the prompt is enclosed by quotes
|
||||
elif cmd_string.startswith('"'):
|
||||
a = shlex.split(cmd_string,comments=True)
|
||||
prompt = a[0]
|
||||
switches = shlex.join(a[1:])
|
||||
else:
|
||||
# no initial quote, so get everything up to the first thing
|
||||
# that looks like a switch
|
||||
if cmd_string.startswith('-'):
|
||||
prompt = ''
|
||||
switches = cmd_string
|
||||
else:
|
||||
switches[0] += element
|
||||
switches[0] += ' '
|
||||
switches[0] = switches[0][: len(switches[0]) - 1]
|
||||
match = re.match('^(.+?)\s(--?[a-zA-Z].+)',cmd_string)
|
||||
if match:
|
||||
prompt,switches = match.groups()
|
||||
else:
|
||||
prompt = cmd_string
|
||||
switches = ''
|
||||
try:
|
||||
self._cmd_switches = self._cmd_parser.parse_args(switches)
|
||||
self._cmd_switches = self._cmd_parser.parse_args(shlex.split(switches,comments=True))
|
||||
setattr(self._cmd_switches,'prompt',prompt)
|
||||
return self._cmd_switches
|
||||
except:
|
||||
return None
|
||||
@ -211,12 +214,16 @@ class Args(object):
|
||||
a = vars(self)
|
||||
a.update(kwargs)
|
||||
switches = list()
|
||||
switches.append(f'"{a["prompt"]}"')
|
||||
prompt = a['prompt']
|
||||
prompt.replace('"','\\"')
|
||||
switches.append(prompt)
|
||||
switches.append(f'-s {a["steps"]}')
|
||||
switches.append(f'-S {a["seed"]}')
|
||||
switches.append(f'-W {a["width"]}')
|
||||
switches.append(f'-H {a["height"]}')
|
||||
switches.append(f'-C {a["cfg_scale"]}')
|
||||
if a['karras_max'] is not None:
|
||||
switches.append(f'--karras_max {a["karras_max"]}')
|
||||
if a['perlin'] > 0:
|
||||
switches.append(f'--perlin {a["perlin"]}')
|
||||
if a['threshold'] > 0:
|
||||
@ -568,10 +575,17 @@ class Args(object):
|
||||
)
|
||||
render_group = parser.add_argument_group('General rendering')
|
||||
img2img_group = parser.add_argument_group('Image-to-image and inpainting')
|
||||
inpainting_group = parser.add_argument_group('Inpainting')
|
||||
outpainting_group = parser.add_argument_group('Outpainting and outcropping')
|
||||
variation_group = parser.add_argument_group('Creating and combining variations')
|
||||
postprocessing_group = parser.add_argument_group('Post-processing')
|
||||
special_effects_group = parser.add_argument_group('Special effects')
|
||||
render_group.add_argument('prompt')
|
||||
deprecated_group = parser.add_argument_group('Deprecated options')
|
||||
render_group.add_argument(
|
||||
'--prompt',
|
||||
default='',
|
||||
help='prompt string',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-s',
|
||||
'--steps',
|
||||
@ -689,7 +703,13 @@ class Args(object):
|
||||
default=6,
|
||||
choices=range(0,10),
|
||||
dest='png_compression',
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). [6]'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--karras_max',
|
||||
type=int,
|
||||
default=None,
|
||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-I',
|
||||
@ -697,12 +717,6 @@ class Args(object):
|
||||
type=str,
|
||||
help='Path to input image for img2img mode (supersedes width and height)',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-M',
|
||||
'--init_mask',
|
||||
type=str,
|
||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-tm',
|
||||
'--text_mask',
|
||||
@ -730,29 +744,68 @@ class Args(object):
|
||||
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||
default=0.75,
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-D',
|
||||
'--out_direction',
|
||||
nargs='+',
|
||||
inpainting_group.add_argument(
|
||||
'-M',
|
||||
'--init_mask',
|
||||
type=str,
|
||||
metavar=('direction', 'pixels'),
|
||||
help='Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size'
|
||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-c',
|
||||
'--outcrop',
|
||||
nargs='+',
|
||||
type=str,
|
||||
metavar=('direction','pixels'),
|
||||
help='Outcrop the image with one or more direction/pixel pairs: -c top 64 bottom 128 left 64 right 64',
|
||||
inpainting_group.add_argument(
|
||||
'--invert_mask',
|
||||
action='store_true',
|
||||
help='Invert the mask',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
inpainting_group.add_argument(
|
||||
'-r',
|
||||
'--inpaint_replace',
|
||||
type=float,
|
||||
default=0.0,
|
||||
help='when inpainting, adjust how aggressively to replace the part of the picture under the mask, from 0.0 (a gentle merge) to 1.0 (replace entirely)',
|
||||
)
|
||||
outpainting_group.add_argument(
|
||||
'-c',
|
||||
'--outcrop',
|
||||
nargs='+',
|
||||
type=str,
|
||||
metavar=('direction','pixels'),
|
||||
help='Outcrop the image with one or more direction/pixel pairs: e.g. -c top 64 bottom 128 left 64 right 64',
|
||||
)
|
||||
outpainting_group.add_argument(
|
||||
'--force_outpaint',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Force outpainting if you have no inpainting mask to pass',
|
||||
)
|
||||
outpainting_group.add_argument(
|
||||
'--seam_size',
|
||||
type=int,
|
||||
default=0,
|
||||
help='When outpainting, size of the mask around the seam between original and outpainted image',
|
||||
)
|
||||
outpainting_group.add_argument(
|
||||
'--seam_blur',
|
||||
type=int,
|
||||
default=0,
|
||||
help='When outpainting, the amount to blur the seam inwards',
|
||||
)
|
||||
outpainting_group.add_argument(
|
||||
'--seam_strength',
|
||||
type=float,
|
||||
default=0.7,
|
||||
help='When outpainting, the img2img strength to use when filling the seam. Values around 0.7 work well',
|
||||
)
|
||||
outpainting_group.add_argument(
|
||||
'--seam_steps',
|
||||
type=int,
|
||||
default=10,
|
||||
help='When outpainting, the number of steps to use to fill the seam. Low values (~10) work well',
|
||||
)
|
||||
outpainting_group.add_argument(
|
||||
'--tile_size',
|
||||
type=int,
|
||||
default=32,
|
||||
help='When outpainting, the tile size to use for filling outpaint areas',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-ft',
|
||||
'--facetool',
|
||||
@ -836,7 +889,14 @@ class Args(object):
|
||||
dest='use_mps_noise',
|
||||
help='Simulate noise on M1 systems to get the same results'
|
||||
)
|
||||
|
||||
deprecated_group.add_argument(
|
||||
'-D',
|
||||
'--out_direction',
|
||||
nargs='+',
|
||||
type=str,
|
||||
metavar=('direction', 'pixels'),
|
||||
help='Older outcropping system. Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size'
|
||||
)
|
||||
return parser
|
||||
|
||||
def format_metadata(**kwargs):
|
||||
@ -872,7 +932,7 @@ def metadata_dumps(opt,
|
||||
|
||||
# remove any image keys not mentioned in RFC #266
|
||||
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
||||
'cfg_scale','threshold','perlin','fnformat', 'step_number','width','height','extra','strength',
|
||||
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
|
||||
'init_img','init_mask','facetool','facetool_strength','upscale']
|
||||
rfc_dict ={}
|
||||
|
||||
@ -923,6 +983,23 @@ def metadata_dumps(opt,
|
||||
|
||||
return metadata
|
||||
|
||||
@functools.lru_cache(maxsize=50)
|
||||
def args_from_png(png_file_path) -> list[Args]:
|
||||
'''
|
||||
Given the path to a PNG file created by invoke.py,
|
||||
retrieves a list of Args objects containing the image
|
||||
data.
|
||||
'''
|
||||
try:
|
||||
meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path)
|
||||
except AttributeError:
|
||||
return [legacy_metadata_load({},png_file_path)]
|
||||
|
||||
try:
|
||||
return metadata_loads(meta)
|
||||
except:
|
||||
return [legacy_metadata_load(meta,png_file_path)]
|
||||
|
||||
@functools.lru_cache(maxsize=50)
|
||||
def metadata_from_png(png_file_path) -> Args:
|
||||
'''
|
||||
@ -930,11 +1007,8 @@ def metadata_from_png(png_file_path) -> Args:
|
||||
an Args object containing the image metadata. Note that this
|
||||
returns a single Args object, not multiple.
|
||||
'''
|
||||
meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path)
|
||||
if 'sd-metadata' in meta and len(meta['sd-metadata'])>0 :
|
||||
return metadata_loads(meta)[0]
|
||||
else:
|
||||
return legacy_metadata_load(meta,png_file_path)
|
||||
args_list = args_from_png(png_file_path)
|
||||
return args_list[0]
|
||||
|
||||
def dream_cmd_from_png(png_file_path):
|
||||
opt = metadata_from_png(png_file_path)
|
||||
@ -949,7 +1023,7 @@ def metadata_loads(metadata) -> list:
|
||||
'''
|
||||
results = []
|
||||
try:
|
||||
if 'grid' in metadata['sd-metadata']:
|
||||
if 'images' in metadata['sd-metadata']:
|
||||
images = metadata['sd-metadata']['images']
|
||||
else:
|
||||
images = [metadata['sd-metadata']['image']]
|
||||
|
@ -4,107 +4,191 @@ weighted subprompts.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c() get the conditioned and unconditioned latent
|
||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
||||
split_weighted_subpromopts() split subprompts, normalize and weight them
|
||||
log_tokenization() print out colour-coded tokens and warn if truncated
|
||||
|
||||
'''
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
|
||||
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
|
||||
from ..models.diffusion.cross_attention_control import CrossAttentionControl
|
||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||
|
||||
|
||||
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
|
||||
|
||||
# Extract Unconditioned Words From Prompt
|
||||
unconditioned_words = ''
|
||||
unconditional_regex = r'\[(.*?)\]'
|
||||
unconditionals = re.findall(unconditional_regex, prompt)
|
||||
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
|
||||
|
||||
if len(unconditionals) > 0:
|
||||
unconditioned_words = ' '.join(unconditionals)
|
||||
|
||||
# Remove Unconditioned Words From Prompt
|
||||
unconditional_regex_compile = re.compile(unconditional_regex)
|
||||
clean_prompt = unconditional_regex_compile.sub(' ', prompt)
|
||||
prompt = re.sub(' +', ' ', clean_prompt)
|
||||
clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
|
||||
prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
|
||||
else:
|
||||
prompt_string_cleaned = prompt_string_uncleaned
|
||||
|
||||
uc = model.get_learned_conditioning([unconditioned_words])
|
||||
pp = PromptParser()
|
||||
|
||||
# get weighted sub-prompts
|
||||
weighted_subprompts = split_weighted_subprompts(
|
||||
prompt, skip_normalize
|
||||
parsed_prompt: Union[FlattenedPrompt, Blend] = None
|
||||
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
|
||||
if legacy_blend is not None:
|
||||
parsed_prompt = legacy_blend
|
||||
else:
|
||||
# we don't support conjunctions for now
|
||||
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
|
||||
|
||||
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
|
||||
print(f">> Parsed prompt to {parsed_prompt}")
|
||||
|
||||
conditioning = None
|
||||
cac_args:CrossAttentionControl.Arguments = None
|
||||
|
||||
if type(parsed_prompt) is Blend:
|
||||
blend: Blend = parsed_prompt
|
||||
embeddings_to_blend = None
|
||||
for flattened_prompt in blend.prompts:
|
||||
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
blend.weights,
|
||||
normalize=blend.normalize_weights)
|
||||
else:
|
||||
flattened_prompt: FlattenedPrompt = parsed_prompt
|
||||
wants_cross_attention_control = type(flattened_prompt) is not Blend \
|
||||
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
||||
if wants_cross_attention_control:
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
|
||||
original_token_count = 0
|
||||
edited_token_count = 0
|
||||
edit_opcodes = []
|
||||
edit_options = []
|
||||
for fragment in flattened_prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
|
||||
to_replace_token_count = get_tokens_length(model, fragment.original)
|
||||
replacement_token_count = get_tokens_length(model, fragment.edited)
|
||||
edit_opcodes.append(('replace',
|
||||
original_token_count, original_token_count + to_replace_token_count,
|
||||
edited_token_count, edited_token_count + replacement_token_count
|
||||
))
|
||||
original_token_count += to_replace_token_count
|
||||
edited_token_count += replacement_token_count
|
||||
edit_options.append(fragment.options)
|
||||
#elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
|
||||
count = get_tokens_length(model, [fragment])
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
|
||||
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
|
||||
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
|
||||
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
|
||||
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
|
||||
# token 'smiling' in the inactive 'cat' edit.
|
||||
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
|
||||
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||
cac_args = CrossAttentionControl.Arguments(
|
||||
edited_conditioning = edited_conditioning,
|
||||
edit_opcodes = edit_opcodes,
|
||||
edit_options = edit_options
|
||||
)
|
||||
else:
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
||||
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
|
||||
if isinstance(conditioning, dict):
|
||||
# hybrid conditioning is in play
|
||||
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
|
||||
if cac_args is not None:
|
||||
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
|
||||
cac_args = None
|
||||
|
||||
return (
|
||||
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
cross_attention_control_args=cac_args
|
||||
)
|
||||
)
|
||||
|
||||
if len(weighted_subprompts) > 1:
|
||||
# i dont know if this is correct.. but it works
|
||||
c = torch.zeros_like(uc)
|
||||
# normalize each "sub prompt" and add it
|
||||
for subprompt, weight in weighted_subprompts:
|
||||
log_tokenization(subprompt, model, log_tokens, weight)
|
||||
c = torch.add(
|
||||
c,
|
||||
model.get_learned_conditioning([subprompt]),
|
||||
alpha=weight,
|
||||
)
|
||||
else: # just standard 1 prompt
|
||||
log_tokenization(prompt, model, log_tokens, 1)
|
||||
c = model.get_learned_conditioning([prompt])
|
||||
uc = model.get_learned_conditioning([unconditioned_words])
|
||||
return (uc, c)
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
"""
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile("""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""", re.VERBOSE)
|
||||
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
original_tokens = original_tokens.cpu().numpy()[0]
|
||||
edited_tokens = edited_tokens.cpu().numpy()[0]
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, log=False, weight=1):
|
||||
if not log:
|
||||
return
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
|
||||
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
weights = [x.weight for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
||||
if not flattened_prompt.is_empty and log_tokens:
|
||||
start_token = model.cond_stage_model.tokenizer.bos_token_id
|
||||
end_token = model.cond_stage_model.tokenizer.eos_token_id
|
||||
tokens_list = tokens[0].tolist()
|
||||
if tokens_list[0] == start_token:
|
||||
tokens_list[0] = '<start>'
|
||||
try:
|
||||
first_end_token_index = tokens_list.index(end_token)
|
||||
tokens_list[first_end_token_index] = '<end>'
|
||||
tokens_list = tokens_list[:first_end_token_index+1]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
|
||||
|
||||
return embeddings, tokens
|
||||
|
||||
def get_tokens_length(model, fragments: list[Fragment]):
|
||||
fragment_texts = [x.text for x in fragments]
|
||||
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
|
||||
return sum([len(x) for x in tokens])
|
||||
|
||||
def flatten_hybrid_conditioning(uncond, cond):
|
||||
'''
|
||||
This handles the choice between a conditional conditioning
|
||||
that is a tensor (used by cross attention) vs one that has additional
|
||||
dimensions as well, as used by 'hybrid'
|
||||
'''
|
||||
assert isinstance(uncond, dict)
|
||||
assert isinstance(cond, dict)
|
||||
cond_flattened = dict()
|
||||
for k in cond:
|
||||
if isinstance(cond[k], list):
|
||||
cond_flattened[k] = [
|
||||
torch.cat([uncond[k][i], cond[k][i]])
|
||||
for i in range(len(cond[k]))
|
||||
]
|
||||
else:
|
||||
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
|
||||
return uncond, cond_flattened
|
||||
|
||||
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
import traceback
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image, ImageFilter
|
||||
from einops import rearrange, repeat
|
||||
@ -28,7 +29,8 @@ class Generator():
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
self.use_mps_noise = False
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
@ -43,14 +45,15 @@ class Generator():
|
||||
self.variation_amount = variation_amount
|
||||
self.with_variations = with_variations
|
||||
|
||||
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
make_image = self.get_make_image(
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler = sampler,
|
||||
init_image = init_image,
|
||||
width = width,
|
||||
height = height,
|
||||
@ -59,12 +62,14 @@ class Generator():
|
||||
perlin = perlin,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
results = []
|
||||
seed = seed if seed is not None else self.new_seed()
|
||||
first_seed = seed
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
with scope(self.model.device.type), self.model.ema_scope():
|
||||
|
||||
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||
# the inpaint-1.5 model. Not sure what it did.... ?
|
||||
with scope(self.model.device.type):
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
x_T = None
|
||||
if self.variation_amount > 0:
|
||||
@ -79,7 +84,8 @@ class Generator():
|
||||
try:
|
||||
x_T = self.get_noise(width,height)
|
||||
except:
|
||||
pass
|
||||
print('** An error occurred while getting initial noise **')
|
||||
print(traceback.format_exc())
|
||||
|
||||
image = make_image(x_T)
|
||||
|
||||
@ -95,10 +101,10 @@ class Generator():
|
||||
|
||||
return results
|
||||
|
||||
def sample_to_image(self,samples):
|
||||
def sample_to_image(self,samples)->Image.Image:
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
Given samples returned from a sampler, converts
|
||||
it into a PIL Image
|
||||
"""
|
||||
x_samples = self.model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
@ -21,6 +21,7 @@ class Embiggen(Generator):
|
||||
def generate(self,prompt,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
@ -63,6 +64,8 @@ class Embiggen(Generator):
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
|
@ -10,11 +10,12 @@ from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # by get_noise()
|
||||
self.init_latent = None # by get_noise()
|
||||
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||
@ -29,7 +30,7 @@ class Img2Img(Generator):
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
init_image = self._image_to_tensor(init_image.convert('RGB'))
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
@ -38,7 +39,7 @@ class Img2Img(Generator):
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
@ -55,7 +56,9 @@ class Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
all_timesteps_count = steps
|
||||
)
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
@ -77,7 +80,10 @@ class Img2Img(Generator):
|
||||
|
||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
if len(image.shape) == 2: # 'L' image, as in a mask
|
||||
image = image[None,None]
|
||||
else: # 'RGB' image
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
if normalize:
|
||||
image = 2.0 * image - 1.0
|
||||
|
@ -2,12 +2,13 @@
|
||||
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import PIL
|
||||
from PIL import Image, ImageFilter
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from einops import rearrange, repeat
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
@ -24,11 +25,128 @@ class Inpaint(Img2Img):
|
||||
self.mask_blur_radius = 0
|
||||
super().__init__(model, precision)
|
||||
|
||||
# Outpaint support code
|
||||
def get_tile_images(self, image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False
|
||||
)
|
||||
|
||||
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a,*tile_size).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:,:,:,:,3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = (tiles_mask > 0)
|
||||
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed = seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1,2)
|
||||
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
|
||||
si = Image.fromarray(st, mode='RGBA')
|
||||
|
||||
return si
|
||||
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||
|
||||
# Detect hard edges
|
||||
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
|
||||
|
||||
# Combine
|
||||
npmask = npgradient + npedge
|
||||
|
||||
# Expand
|
||||
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
if edge_blur > 0:
|
||||
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
|
||||
|
||||
return ImageOps.invert(new_mask)
|
||||
|
||||
|
||||
def seam_paint(self,
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,strength,
|
||||
noise
|
||||
) -> Image.Image:
|
||||
hard_mask = self.pil_image.split()[-1].copy()
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image = im.copy().convert('RGBA'),
|
||||
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
|
||||
strength = strength,
|
||||
mask_blur_radius = 0,
|
||||
seam_size = 0
|
||||
)
|
||||
|
||||
result = make_image(noise)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
mask_blur_radius: int = 8,
|
||||
step_callback=None,inpaint_replace=False, **kwargs):
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
@ -37,7 +155,17 @@ class Inpaint(Img2Img):
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
self.pil_image = init_image
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
# Fill missing areas of original image
|
||||
init_filled = self.tile_fill_missing(
|
||||
self.pil_image.copy(),
|
||||
seed = self.seed,
|
||||
tile_size = tile_size
|
||||
)
|
||||
init_filled.paste(init_image, (0,0), init_image.split()[-1])
|
||||
|
||||
# Create init tensor
|
||||
init_image = self._image_to_tensor(init_filled.convert('RGB'))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image
|
||||
@ -73,7 +201,8 @@ class Inpaint(Img2Img):
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
|
||||
print(f">> target t_enc is {t_enc} steps")
|
||||
|
||||
@ -105,38 +234,56 @@ class Inpaint(Img2Img):
|
||||
mask = mask_image,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
result = self.sample_to_image(samples)
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
if seam_size > 0:
|
||||
result = self.seam_paint(
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
seam_strength,
|
||||
x_T)
|
||||
|
||||
return result
|
||||
|
||||
return make_image
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
pil_mask = self.pil_mask
|
||||
pil_image = self.pil_image
|
||||
mask_blur_radius = self.mask_blur_radius
|
||||
|
||||
def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
|
||||
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
|
||||
pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
# Note that this doesn't use the mask, which would exclude some source image pixels from the
|
||||
# histogram and cause slight color changes.
|
||||
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
|
||||
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
|
||||
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
|
||||
init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
|
||||
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
|
||||
# Get numpy version
|
||||
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
|
||||
# Get numpy version of result
|
||||
np_image = np.asarray(image, dtype=np.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
@ -149,6 +296,16 @@ class Inpaint(Img2Img):
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
|
||||
matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
|
||||
return corrected_result
|
||||
|
153
ldm/invoke/generator/omnibus.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
from PIL import Image, ImageOps
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.generator.txt2img import Txt2Img
|
||||
|
||||
class Omnibus(Img2Img,Txt2Img):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
init_image = None,
|
||||
mask_image = None,
|
||||
strength = None,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
perlin=0.0,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
num_samples = 1
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, Image.Image):
|
||||
if init_image.mode != 'RGB':
|
||||
init_image = init_image.convert('RGB')
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
if isinstance(mask_image, Image.Image):
|
||||
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
|
||||
|
||||
t_enc = steps
|
||||
|
||||
if init_image is not None and mask_image is not None: # inpainting
|
||||
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||
|
||||
elif init_image is not None: # img2img
|
||||
scope = choose_autocast(self.precision)
|
||||
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
# create a completely black mask (1s)
|
||||
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
|
||||
# and the masked image is just a copy of the original
|
||||
masked_image = init_image
|
||||
|
||||
else: # txt2img
|
||||
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
|
||||
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
|
||||
masked_image = init_image
|
||||
|
||||
self.init_latent = init_image
|
||||
height = init_image.shape[2]
|
||||
width = init_image.shape[3]
|
||||
model = self.model
|
||||
|
||||
def make_image(x_T):
|
||||
with torch.no_grad():
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
|
||||
batch = self.make_batch_sd(
|
||||
init_image,
|
||||
mask_image,
|
||||
masked_image,
|
||||
prompt=prompt,
|
||||
device=model.device,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
|
||||
c = model.cond_stage_model.encode(batch["txt"])
|
||||
c_cat = list()
|
||||
for ck in model.concat_keys:
|
||||
cc = batch[ck].float()
|
||||
if ck != model.masked_image_key:
|
||||
bchw = [num_samples, 4, height//8, width//8]
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
|
||||
# cond
|
||||
cond={"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
|
||||
# uncond cond
|
||||
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||
shape = [model.channels, height//8, width//8]
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = cond,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc_full,
|
||||
eta = 1.0,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
)
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def make_batch_sd(
|
||||
self,
|
||||
image,
|
||||
mask,
|
||||
masked_image,
|
||||
prompt,
|
||||
device,
|
||||
num_samples=1):
|
||||
batch = {
|
||||
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"txt": num_samples * [prompt],
|
||||
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
}
|
||||
return batch
|
||||
|
||||
def get_noise(self, width:int, height:int):
|
||||
if self.init_latent is not None:
|
||||
height = self.init_latent.shape[2]
|
||||
width = self.init_latent.shape[3]
|
||||
return Txt2Img.get_noise(self,width,height)
|
@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -19,7 +21,7 @@ class Txt2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -43,6 +45,7 @@ class Txt2Img(Generator):
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
|
@ -5,9 +5,11 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
from ldm.invoke.generator.omnibus import Omnibus
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from PIL import Image
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -22,31 +24,29 @@ class Txt2Img2Img(Generator):
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
scale_dim = min(width, height)
|
||||
scale = 512 / scale_dim
|
||||
|
||||
init_width = math.ceil(scale * width / 64) * 64
|
||||
init_height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
|
||||
trained_square = 512 * 512
|
||||
actual_square = width * height
|
||||
scale = math.sqrt(trained_square / actual_square)
|
||||
def make_image(x_T):
|
||||
|
||||
init_width = math.ceil(scale * width / 64) * 64
|
||||
init_height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
init_height // self.downsampling_factor,
|
||||
init_width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
|
||||
#x = self.get_noise(init_width, init_height)
|
||||
x = x_T
|
||||
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
@ -60,17 +60,18 @@ class Txt2Img2Img(Generator):
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback
|
||||
img_callback = step_callback,
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
)
|
||||
|
||||
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
|
||||
# resizing
|
||||
samples = torch.nn.functional.interpolate(
|
||||
samples,
|
||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||
samples,
|
||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||
mode="bilinear"
|
||||
)
|
||||
|
||||
@ -94,6 +95,8 @@ class Txt2Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
all_timesteps_count=steps
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
@ -101,8 +104,49 @@ class Txt2Img2Img(Generator):
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpaing model is so conservative
|
||||
# it doesn't change the image (much)
|
||||
def inpaint_make_image(x_T):
|
||||
omnibus = Omnibus(self.model,self.precision)
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
width=init_width,
|
||||
height=init_height,
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
assert result is not None and len(result)>0,'** txt2img failed **'
|
||||
image = result[0][0]
|
||||
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
|
||||
print(kwargs.pop('init_image',None))
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=interpolated_image,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=result[0][1],
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
return result[0][0]
|
||||
|
||||
if sampler.uses_inpainting_model():
|
||||
return inpaint_make_image
|
||||
else:
|
||||
return make_image
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height,scale = True):
|
||||
@ -116,7 +160,7 @@ class Txt2Img2Img(Generator):
|
||||
else:
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
|
||||
device = self.model.device
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
@ -130,3 +174,4 @@ class Txt2Img2Img(Generator):
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device=device)
|
||||
|
||||
|
@ -13,6 +13,7 @@ import gc
|
||||
import hashlib
|
||||
import psutil
|
||||
import transformers
|
||||
import traceback
|
||||
import os
|
||||
from sys import getrefcount
|
||||
from omegaconf import OmegaConf
|
||||
@ -73,6 +74,7 @@ class ModelCache(object):
|
||||
self.models[model_name]['hash'] = hash
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
print(traceback.format_exc())
|
||||
print(f'** restoring {self.current_model}')
|
||||
self.get_model(self.current_model)
|
||||
return None
|
||||
|
686
ldm/invoke/prompt_parser.py
Normal file
@ -0,0 +1,686 @@
|
||||
import string
|
||||
from typing import Union, Optional
|
||||
import re
|
||||
import pyparsing as pp
|
||||
|
||||
class Prompt():
|
||||
"""
|
||||
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
|
||||
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
|
||||
that are to be blended together are stored individuall as Prompt objects.
|
||||
|
||||
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
|
||||
to produce a FlattenedPrompt.
|
||||
"""
|
||||
def __init__(self, parts: list):
|
||||
for c in parts:
|
||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
|
||||
self.children = parts
|
||||
def __repr__(self):
|
||||
return f"Prompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Prompt and other.children == self.children
|
||||
|
||||
class BaseFragment:
|
||||
pass
|
||||
|
||||
class FlattenedPrompt():
|
||||
"""
|
||||
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
|
||||
"""
|
||||
def __init__(self, parts: list=[]):
|
||||
self.children = []
|
||||
for part in parts:
|
||||
self.append(part)
|
||||
|
||||
def append(self, fragment: Union[list, BaseFragment, tuple]):
|
||||
# verify type correctness
|
||||
if type(fragment) is list:
|
||||
for x in fragment:
|
||||
self.append(x)
|
||||
elif issubclass(type(fragment), BaseFragment):
|
||||
self.children.append(fragment)
|
||||
elif type(fragment) is tuple:
|
||||
# upgrade tuples to Fragments
|
||||
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
self.children.append(Fragment(fragment[0], fragment[1]))
|
||||
else:
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
return len(self.children) == 0 or \
|
||||
(len(self.children) == 1 and len(self.children[0].text) == 0)
|
||||
|
||||
def __repr__(self):
|
||||
return f"FlattenedPrompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is FlattenedPrompt and other.children == self.children
|
||||
|
||||
|
||||
class Fragment(BaseFragment):
|
||||
"""
|
||||
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
|
||||
"""
|
||||
def __init__(self, text: str, weight: float=1):
|
||||
assert(type(text) is str)
|
||||
if '\\"' in text or '\\(' in text or '\\)' in text:
|
||||
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
|
||||
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
|
||||
self.text = text
|
||||
self.weight = float(weight)
|
||||
|
||||
def __repr__(self):
|
||||
return "Fragment:'"+self.text+"'@"+str(self.weight)
|
||||
def __eq__(self, other):
|
||||
return type(other) is Fragment \
|
||||
and other.text == self.text \
|
||||
and other.weight == self.weight
|
||||
|
||||
class Attention():
|
||||
"""
|
||||
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
|
||||
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
|
||||
|
||||
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
||||
"""
|
||||
def __init__(self, weight: float, children: list):
|
||||
self.weight = weight
|
||||
self.children = children
|
||||
#print(f"A: requested attention '{children}' to {weight}")
|
||||
|
||||
def __repr__(self):
|
||||
return f"Attention:'{self.children}' @ {self.weight}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
||||
|
||||
class CrossAttentionControlledFragment(BaseFragment):
|
||||
pass
|
||||
|
||||
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
||||
"""
|
||||
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
|
||||
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
|
||||
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
|
||||
the result should be an "edited" image that looks like the "original" image with concepts swapped.
|
||||
|
||||
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
|
||||
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
|
||||
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
|
||||
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
|
||||
or it may represent a larger portion of the token sequence:
|
||||
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
|
||||
edited=[Fragment('a smiling dog sitting on a car')])
|
||||
|
||||
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
|
||||
FlattenedPrompt([
|
||||
Fragment('a'),
|
||||
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
|
||||
Fragment('sitting on a car')
|
||||
])
|
||||
"""
|
||||
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
|
||||
self.original = original
|
||||
self.edited = edited
|
||||
|
||||
default_options = {
|
||||
's_start': 0.0,
|
||||
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
|
||||
't_start': 0.0,
|
||||
't_end': 1.0
|
||||
}
|
||||
merged_options = default_options
|
||||
if options is not None:
|
||||
shape_freedom = options.pop('shape_freedom', None)
|
||||
if shape_freedom is not None:
|
||||
# high shape freedom = SD can do what it wants with the shape of the object
|
||||
# high shape freedom => s_end = 0
|
||||
# low shape freedom => s_end = 1
|
||||
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
|
||||
# and there is very little perceptible difference as s_end increases above 0.5
|
||||
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
|
||||
# -> cube root and subtract from 1.0
|
||||
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
|
||||
#print('converted shape_freedom argument to', merged_options)
|
||||
merged_options.update(options)
|
||||
|
||||
self.options = merged_options
|
||||
|
||||
def __repr__(self):
|
||||
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
|
||||
def __eq__(self, other):
|
||||
return type(other) is CrossAttentionControlSubstitute \
|
||||
and other.original == self.original \
|
||||
and other.edited == self.edited \
|
||||
and other.options == self.options
|
||||
|
||||
|
||||
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
||||
def __init__(self, fragment: Fragment):
|
||||
self.fragment = fragment
|
||||
def __repr__(self):
|
||||
return "CrossAttentionControlAppend:",self.fragment
|
||||
def __eq__(self, other):
|
||||
return type(other) is CrossAttentionControlAppend \
|
||||
and other.fragment == self.fragment
|
||||
|
||||
|
||||
|
||||
class Conjunction():
|
||||
"""
|
||||
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
|
||||
by weighted sum in latent space.
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list = None):
|
||||
# force everything to be a Prompt
|
||||
#print("making conjunction with", parts)
|
||||
self.prompts = [x if (type(x) is Prompt
|
||||
or type(x) is Blend
|
||||
or type(x) is FlattenedPrompt)
|
||||
else Prompt(x) for x in prompts]
|
||||
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
|
||||
if len(self.weights) != len(self.prompts):
|
||||
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
||||
self.type = 'AND'
|
||||
|
||||
def __repr__(self):
|
||||
return f"Conjunction:{self.prompts} | weights {self.weights}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Conjunction \
|
||||
and other.prompts == self.prompts \
|
||||
and other.weights == self.weights
|
||||
|
||||
|
||||
class Blend():
|
||||
"""
|
||||
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
|
||||
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
|
||||
Prompts.
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||
if len(prompts) != len(weights):
|
||||
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
||||
for c in prompts:
|
||||
if type(c) is not Prompt and type(c) is not FlattenedPrompt:
|
||||
raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
|
||||
# upcast all lists to Prompt objects
|
||||
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
|
||||
else Prompt(x) for x in prompts]
|
||||
self.prompts = prompts
|
||||
self.weights = weights
|
||||
self.normalize_weights = normalize_weights
|
||||
|
||||
def __repr__(self):
|
||||
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
|
||||
def __eq__(self, other):
|
||||
return other.__repr__() == self.__repr__()
|
||||
|
||||
|
||||
class PromptParser():
|
||||
|
||||
class ParsingException(Exception):
|
||||
pass
|
||||
|
||||
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
||||
|
||||
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
||||
|
||||
|
||||
def parse_conjunction(self, prompt: str) -> Conjunction:
|
||||
'''
|
||||
:param prompt: The prompt string to parse
|
||||
:return: a Conjunction representing the parsed results.
|
||||
'''
|
||||
#print(f"!!parsing '{prompt}'")
|
||||
|
||||
if len(prompt.strip()) == 0:
|
||||
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
|
||||
|
||||
root = self.conjunction.parse_string(prompt)
|
||||
#print(f"'{prompt}' parsed to root", root)
|
||||
#fused = fuse_fragments(parts)
|
||||
#print("fused to", fused)
|
||||
|
||||
return self.flatten(root[0])
|
||||
|
||||
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
||||
|
||||
|
||||
def flatten(self, root: Conjunction) -> Conjunction:
|
||||
"""
|
||||
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
||||
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
||||
that can be readily tokenized without the need to walk a complex tree structure.
|
||||
|
||||
:param root: The Conjunction to flatten.
|
||||
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
||||
"""
|
||||
|
||||
#print("flattening", root)
|
||||
|
||||
def fuse_fragments(items):
|
||||
# print("fusing fragments in ", items)
|
||||
result = []
|
||||
for x in items:
|
||||
if type(x) is CrossAttentionControlSubstitute:
|
||||
original_fused = fuse_fragments(x.original)
|
||||
edited_fused = fuse_fragments(x.edited)
|
||||
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
|
||||
else:
|
||||
last_weight = result[-1].weight \
|
||||
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
|
||||
else None
|
||||
this_text = x.text
|
||||
this_weight = x.weight
|
||||
if last_weight is not None and last_weight == this_weight:
|
||||
last_text = result[-1].text
|
||||
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
|
||||
else:
|
||||
result.append(x)
|
||||
return result
|
||||
|
||||
def flatten_internal(node, weight_scale, results, prefix):
|
||||
#print(prefix + "flattening", node, "...")
|
||||
if type(node) is pp.ParseResults:
|
||||
for x in node:
|
||||
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
||||
#print(prefix, " ParseResults expanded, results is now", results)
|
||||
elif type(node) is Attention:
|
||||
# if node.weight < 1:
|
||||
# todo: inject a blend when flattening attention with weight <1"
|
||||
for index,c in enumerate(node.children):
|
||||
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
|
||||
elif type(node) is Fragment:
|
||||
results += [Fragment(node.text, node.weight*weight_scale)]
|
||||
elif type(node) is CrossAttentionControlSubstitute:
|
||||
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
|
||||
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
|
||||
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
|
||||
elif type(node) is Blend:
|
||||
flattened_subprompts = []
|
||||
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
|
||||
for prompt in node.prompts:
|
||||
# prompt is a list
|
||||
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
|
||||
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
|
||||
elif type(node) is Prompt:
|
||||
#print(prefix + "about to flatten Prompt with children", node.children)
|
||||
flattened_prompt = []
|
||||
for child in node.children:
|
||||
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
|
||||
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
|
||||
#print(prefix + "after flattening Prompt, results is", results)
|
||||
else:
|
||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||
return results
|
||||
|
||||
|
||||
flattened_parts = []
|
||||
for part in root.prompts:
|
||||
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
||||
|
||||
#print("flattened to", flattened_parts)
|
||||
|
||||
weights = root.weights
|
||||
return Conjunction(flattened_parts, weights)
|
||||
|
||||
|
||||
|
||||
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
||||
|
||||
lparen = pp.Literal("(").suppress()
|
||||
rparen = pp.Literal(")").suppress()
|
||||
quotes = pp.Literal('"').suppress()
|
||||
comma = pp.Literal(",").suppress()
|
||||
|
||||
# accepts int or float notation, always maps to float
|
||||
number = pp.pyparsing_common.real | \
|
||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||
|
||||
attention = pp.Forward()
|
||||
quoted_fragment = pp.Forward()
|
||||
parenthesized_fragment = pp.Forward()
|
||||
cross_attention_substitute = pp.Forward()
|
||||
|
||||
def make_text_fragment(x):
|
||||
#print("### making fragment for", x)
|
||||
if type(x[0]) is Fragment:
|
||||
assert(False)
|
||||
if type(x) is str:
|
||||
return Fragment(x)
|
||||
elif type(x) is pp.ParseResults or type(x) is list:
|
||||
#print(f'converting {type(x).__name__} to Fragment')
|
||||
return Fragment(' '.join([s for s in x]))
|
||||
else:
|
||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||
|
||||
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
|
||||
escapes = []
|
||||
for c in escaped_chars_to_ignore:
|
||||
escapes.append(pp.Literal('\\'+c))
|
||||
return pp.Combine(pp.OneOrMore(
|
||||
pp.MatchFirst(escapes + [pp.CharsNotIn(
|
||||
string.whitespace + escaped_chars_to_ignore,
|
||||
exact=1
|
||||
)])
|
||||
))
|
||||
|
||||
|
||||
|
||||
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
||||
#print(f"parsing fragment string for {x}")
|
||||
fragment_string = x[0]
|
||||
#print(f"ppparsing fragment string \"{fragment_string}\"")
|
||||
|
||||
if len(fragment_string.strip()) == 0:
|
||||
return Fragment('')
|
||||
|
||||
if in_quotes:
|
||||
# escape unescaped quotes
|
||||
fragment_string = fragment_string.replace('"', '\\"')
|
||||
|
||||
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
||||
try:
|
||||
result = pp.Group(pp.MatchFirst([
|
||||
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
|
||||
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
||||
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
|
||||
#print("parsed to", result)
|
||||
return result
|
||||
except pp.ParseException as e:
|
||||
#print("parse_fragment_str couldn't parse prompt string:", e)
|
||||
raise
|
||||
|
||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
||||
|
||||
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
|
||||
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
|
||||
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
||||
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
||||
|
||||
empty = (
|
||||
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
||||
|
||||
|
||||
def not_ends_with_swap(x):
|
||||
#print("trying to match:", x)
|
||||
return not x[0].endswith('.swap')
|
||||
|
||||
unquoted_word = (pp.Combine(pp.OneOrMore(
|
||||
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
||||
(pp.CharsNotIn(string.whitespace + '\\"()', exact=1)
|
||||
)))
|
||||
# don't whitespace when the next word starts with +, eg "badly +formed"
|
||||
+ (pp.White().suppress() |
|
||||
# don't eat +/-
|
||||
pp.NotAny(pp.Word('+') | pp.Word('-'))
|
||||
)
|
||||
)
|
||||
|
||||
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
|
||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
||||
|
||||
parenthesized_fragment << (lparen +
|
||||
pp.Or([
|
||||
(parenthesized_fragment),
|
||||
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
|
||||
(pp.Combine(pp.OneOrMore(
|
||||
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
||||
pp.CharsNotIn(string.whitespace + '\\"()', exact=1) |
|
||||
pp.White()
|
||||
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
|
||||
pp.Empty()
|
||||
]) + rparen)
|
||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
||||
|
||||
debug_attention = False
|
||||
# attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
|
||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
||||
attention_with_parens = pp.Forward()
|
||||
attention_without_parens = pp.Forward()
|
||||
|
||||
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
|
||||
.set_name("attention_foot")\
|
||||
.set_debug(False)
|
||||
attention_with_parens <<= pp.Group(
|
||||
lparen +
|
||||
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
|
||||
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
|
||||
)
|
||||
+ rparen + attention_with_parens_foot)
|
||||
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
|
||||
|
||||
attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots')
|
||||
attention_without_parens <<= pp.Group(pp.MatchFirst([
|
||||
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
|
||||
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
|
||||
+ attention_without_parens_foot#.leave_whitespace()
|
||||
]))
|
||||
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
||||
|
||||
|
||||
attention << pp.MatchFirst([attention_with_parens,
|
||||
attention_without_parens
|
||||
])
|
||||
attention.set_name('attention')
|
||||
|
||||
def make_attention(x):
|
||||
#print("entered make_attention with", x)
|
||||
children = x[0][:-1]
|
||||
weight_raw = x[0][-1]
|
||||
weight = 1.0
|
||||
if type(weight_raw) is float or type(weight_raw) is int:
|
||||
weight = weight_raw
|
||||
elif type(weight_raw) is str:
|
||||
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
||||
weight = pow(base, len(weight_raw))
|
||||
|
||||
#print("making Attention from", children, "with weight", weight)
|
||||
|
||||
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
|
||||
|
||||
attention_with_parens.set_parse_action(make_attention)
|
||||
attention_without_parens.set_parse_action(make_attention)
|
||||
|
||||
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
|
||||
|
||||
# cross-attention control
|
||||
empty_string = ((lparen + rparen) |
|
||||
pp.Literal('""').suppress() |
|
||||
(lparen + pp.Literal('""').suppress() + rparen)
|
||||
).set_parse_action(lambda x: Fragment(""))
|
||||
empty_string.set_name('empty_string')
|
||||
|
||||
# cross attention control
|
||||
debug_cross_attention_control = False
|
||||
original_fragment = pp.MatchFirst([
|
||||
quoted_fragment.set_debug(debug_cross_attention_control),
|
||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
||||
pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
|
||||
empty_string.set_debug(debug_cross_attention_control),
|
||||
])
|
||||
# support keyword=number arguments
|
||||
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
||||
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
||||
edited_fragment = pp.MatchFirst([
|
||||
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
|
||||
lparen +
|
||||
(quoted_fragment | attention |
|
||||
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
|
||||
) +
|
||||
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
|
||||
rparen,
|
||||
parenthesized_fragment
|
||||
])
|
||||
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
|
||||
|
||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
||||
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
||||
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
|
||||
|
||||
def make_cross_attention_substitute(x):
|
||||
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
|
||||
#if len(x>2):
|
||||
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
|
||||
#print("made", cacs)
|
||||
return cacs
|
||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||
|
||||
|
||||
# root prompt definition
|
||||
debug_root_prompt = False
|
||||
prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt),
|
||||
attention.set_debug(debug_root_prompt),
|
||||
quoted_fragment.set_debug(debug_root_prompt),
|
||||
parenthesized_fragment.set_debug(debug_root_prompt),
|
||||
unquoted_word.set_debug(debug_root_prompt),
|
||||
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
|
||||
) + pp.StringEnd()) \
|
||||
.set_name('prompt') \
|
||||
.set_parse_action(lambda x: Prompt(x)) \
|
||||
.set_debug(debug_root_prompt)
|
||||
|
||||
#print("parsing test:", prompt.parse_string("spaced eyes--"))
|
||||
#print("parsing test:", prompt.parse_string("eyes--"))
|
||||
|
||||
# weighted blend of prompts
|
||||
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
|
||||
# int weights.
|
||||
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
|
||||
|
||||
def make_prompt_from_quoted_string(x):
|
||||
#print(' got quoted prompt', x)
|
||||
|
||||
x_unquoted = x[0][1:-1]
|
||||
if len(x_unquoted.strip()) == 0:
|
||||
# print(' b : just an empty string')
|
||||
return Prompt([Fragment('')])
|
||||
#print(f' b parsing \'{x_unquoted}\'')
|
||||
x_parsed = prompt.parse_string(x_unquoted)
|
||||
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
|
||||
return x_parsed[0]
|
||||
|
||||
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
||||
quoted_prompt.set_name('quoted_prompt')
|
||||
|
||||
debug_blend=False
|
||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
|
||||
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
|
||||
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
||||
+ pp.Literal(".blend").suppress()
|
||||
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
||||
blend.set_debug(debug_blend)
|
||||
|
||||
def make_blend(x):
|
||||
prompts = x[0][0]
|
||||
weights = x[0][1]
|
||||
normalize = True
|
||||
if weights[-1] == 'no_normalize':
|
||||
normalize = False
|
||||
weights = weights[:-1]
|
||||
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
|
||||
|
||||
blend.set_parse_action(make_blend)
|
||||
|
||||
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
|
||||
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
|
||||
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
|
||||
+ pp.Literal(".and").suppress()
|
||||
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
|
||||
def make_conjunction(x):
|
||||
parts_raw = x[0][0]
|
||||
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
|
||||
parts = [part for part in parts_raw]
|
||||
return Conjunction(parts, weights)
|
||||
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
|
||||
|
||||
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
|
||||
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
||||
|
||||
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
|
||||
conjunction.set_debug(False)
|
||||
|
||||
# top-level is a conjunction of one or more blends or prompts
|
||||
return conjunction, prompt
|
||||
|
||||
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
"""
|
||||
Legacy blend parsing.
|
||||
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile("""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""", re.VERBOSE)
|
||||
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, log=False, weight=1):
|
||||
if not log:
|
||||
return
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', 'x` ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
@ -89,6 +89,9 @@ class Outcrop(object):
|
||||
def _extend(self,image:Image,pixels:int)-> Image:
|
||||
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
|
||||
|
||||
mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \
|
||||
else pixels *2
|
||||
|
||||
# first paste places old image at top of extended image, stretch
|
||||
# it, and applies a gaussian blur to it
|
||||
# take the top half region, stretch and paste it
|
||||
@ -105,7 +108,9 @@ class Outcrop(object):
|
||||
|
||||
# now make the top part transparent to use as a mask
|
||||
alpha = extended_img.getchannel('A')
|
||||
alpha.paste(0,(0,0,extended_img.width,pixels*2))
|
||||
alpha.paste(0,(0,0,extended_img.width,mask_height))
|
||||
extended_img.putalpha(alpha)
|
||||
|
||||
extended_img.save('outputs/curly_extended.png')
|
||||
|
||||
return extended_img
|
||||
|
@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
238
ldm/models/diffusion/cross_attention_control.py
Normal file
@ -0,0 +1,238 @@
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
class CrossAttentionControl:
|
||||
|
||||
class Arguments:
|
||||
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
||||
"""
|
||||
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
|
||||
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
|
||||
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
|
||||
"""
|
||||
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
|
||||
self.edited_conditioning = edited_conditioning
|
||||
self.edit_opcodes = edit_opcodes
|
||||
|
||||
if edited_conditioning is not None:
|
||||
assert len(edit_opcodes) == len(edit_options), \
|
||||
"there must be 1 edit_options dict for each edit_opcodes tuple"
|
||||
non_none_edit_options = [x for x in edit_options if x is not None]
|
||||
assert len(non_none_edit_options)>0, "missing edit_options"
|
||||
if len(non_none_edit_options)>1:
|
||||
print('warning: cross-attention control options are not working properly for >1 edit')
|
||||
self.edit_options = non_none_edit_options[0]
|
||||
|
||||
class Context:
|
||||
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
@classmethod
|
||||
def remove_cross_attention_control(cls, model):
|
||||
cls.remove_attention_function(model)
|
||||
|
||||
@classmethod
|
||||
def setup_cross_attention_control(cls, model,
|
||||
cross_attention_control_args: Arguments
|
||||
):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
:param model: The unet model to inject into.
|
||||
:param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# adapted from init_attention_edit
|
||||
device = cross_attention_control_args.edited_conditioning.device
|
||||
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.zeros(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
|
||||
m.last_attn_slice_mask = None
|
||||
m.last_attn_slice_indices = None
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
|
||||
m.last_attn_slice_mask = mask.to(device)
|
||||
m.last_attn_slice_indices = indices.to(device)
|
||||
|
||||
cls.inject_attention_function(model)
|
||||
|
||||
|
||||
class CrossAttentionType(Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
@classmethod
|
||||
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
|
||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
|
||||
|
||||
opts = context.arguments.edit_options
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||
to_control.append(cls.CrossAttentionType.SELF)
|
||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
||||
to_control.append(cls.CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_attention_modules(cls, model, which: CrossAttentionType):
|
||||
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
|
||||
return [module for name, module in model.named_modules() if
|
||||
type(module).__name__ == "CrossAttention" and which_attn in name]
|
||||
|
||||
@classmethod
|
||||
def clear_requests(cls, model):
|
||||
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
|
||||
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
|
||||
for m in self_attention_modules+tokens_attention_modules:
|
||||
m.save_last_attn_slice = False
|
||||
m.use_last_attn_slice = False
|
||||
|
||||
@classmethod
|
||||
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
# clear out the saved slice in case the outermost dim changes
|
||||
m.last_attn_slice = None
|
||||
m.save_last_attn_slice = True
|
||||
|
||||
@classmethod
|
||||
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
m.use_last_attn_slice = True
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def inject_attention_function(cls, unet):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
|
||||
|
||||
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
|
||||
|
||||
attn_slice = suggested_attention_slice
|
||||
if dim is not None:
|
||||
start = offset
|
||||
end = start+slice_size
|
||||
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
#else:
|
||||
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
|
||||
|
||||
if self.use_last_attn_slice:
|
||||
this_attn_slice = attn_slice
|
||||
if self.last_attn_slice_mask is not None:
|
||||
# indices and mask operate on dim=2, no need to slice
|
||||
base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
|
||||
base_attn_slice_mask = self.last_attn_slice_mask
|
||||
if dim is None:
|
||||
base_attn_slice = base_attn_slice_full
|
||||
#print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
||||
elif dim == 0:
|
||||
base_attn_slice = base_attn_slice_full[start:end]
|
||||
#print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
||||
elif dim == 1:
|
||||
base_attn_slice = base_attn_slice_full[:, start:end]
|
||||
#print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
||||
|
||||
attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
|
||||
base_attn_slice * base_attn_slice_mask
|
||||
else:
|
||||
if dim is None:
|
||||
attn_slice = self.last_attn_slice
|
||||
#print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
elif dim == 0:
|
||||
attn_slice = self.last_attn_slice[start:end]
|
||||
#print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
elif dim == 1:
|
||||
attn_slice = self.last_attn_slice[:, start:end]
|
||||
#print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
|
||||
if self.save_last_attn_slice:
|
||||
if dim is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
elif dim == 0:
|
||||
# dynamically grow last_attn_slice if needed
|
||||
if self.last_attn_slice is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
#print("no last_attn_slice: shape now", self.last_attn_slice.shape)
|
||||
elif self.last_attn_slice.shape[0] == start:
|
||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
|
||||
assert(self.last_attn_slice.shape[0] == end)
|
||||
#print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
||||
else:
|
||||
# no need to grow
|
||||
self.last_attn_slice[start:end] = attn_slice
|
||||
#print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
||||
|
||||
elif dim == 1:
|
||||
# dynamically grow last_attn_slice if needed
|
||||
if self.last_attn_slice is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
elif self.last_attn_slice.shape[1] == start:
|
||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
|
||||
assert(self.last_attn_slice.shape[1] == end)
|
||||
else:
|
||||
# no need to grow
|
||||
self.last_attn_slice[:, start:end] = attn_slice
|
||||
|
||||
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
|
||||
if dim is None:
|
||||
weights = self.last_attn_slice_weights
|
||||
elif dim == 0:
|
||||
weights = self.last_attn_slice_weights[start:end]
|
||||
elif dim == 1:
|
||||
weights = self.last_attn_slice_weights[:, start:end]
|
||||
attn_slice = attn_slice * weights
|
||||
|
||||
return attn_slice
|
||||
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.last_attn_slice = None
|
||||
module.last_attn_slice_indices = None
|
||||
module.last_attn_slice_mask = None
|
||||
module.use_last_attn_weights = False
|
||||
module.use_last_attn_slice = False
|
||||
module.save_last_attn_slice = False
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
|
||||
@classmethod
|
||||
def remove_attention_function(cls, unet):
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.set_attention_slice_wrangler(None)
|
||||
|
@ -1,10 +1,7 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
@ -12,6 +9,21 @@ class DDIMSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__(model,schedule,model.num_timesteps,device)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
# This is the central routine
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
@ -29,6 +41,7 @@ class DDIMSampler(Sampler):
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
@ -37,16 +50,17 @@ class DDIMSampler(Sampler):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(
|
||||
x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index
|
||||
)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
|
@ -19,6 +19,7 @@ from functools import partial
|
||||
from tqdm import tqdm
|
||||
from torchvision.utils import make_grid
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from omegaconf import ListConfig
|
||||
import urllib
|
||||
|
||||
from ldm.util import (
|
||||
@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self.model)
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
self.use_scheduler = scheduler_config is not None
|
||||
if self.use_scheduler:
|
||||
@ -820,21 +821,21 @@ class LatentDiffusion(DDPM):
|
||||
)
|
||||
return self.scale_factor * z
|
||||
|
||||
def get_learned_conditioning(self, c):
|
||||
def get_learned_conditioning(self, c, **kwargs):
|
||||
if self.cond_stage_forward is None:
|
||||
if hasattr(self.cond_stage_model, 'encode') and callable(
|
||||
self.cond_stage_model.encode
|
||||
):
|
||||
c = self.cond_stage_model.encode(
|
||||
c, embedding_manager=self.embedding_manager
|
||||
c, embedding_manager=self.embedding_manager,**kwargs
|
||||
)
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
else:
|
||||
c = self.cond_stage_model(c)
|
||||
c = self.cond_stage_model(c, **kwargs)
|
||||
else:
|
||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
|
||||
return c
|
||||
|
||||
def meshgrid(self, h, w):
|
||||
@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||
if null_label is not None:
|
||||
xc = null_label
|
||||
if isinstance(xc, ListConfig):
|
||||
xc = list(xc)
|
||||
if isinstance(xc, dict) or isinstance(xc, list):
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
if hasattr(xc, "to"):
|
||||
xc = xc.to(self.device)
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
# todo: get null label from cond_stage_model
|
||||
raise NotImplementedError()
|
||||
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||
return c
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(
|
||||
self,
|
||||
@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(x, t, context=cc)
|
||||
elif self.conditioning_key == 'hybrid':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
out = self.diffusion_model(xc, t, context=cc)
|
||||
elif self.conditioning_key == 'adm':
|
||||
cc = c_crossattn[0]
|
||||
@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
||||
cond_img = torch.stack(bbox_imgs, dim=0)
|
||||
logs['bbox_image'] = cond_img
|
||||
return logs
|
||||
|
||||
class LatentInpaintDiffusion(LatentDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
concat_keys=("mask", "masked_image"),
|
||||
masked_image_key="masked_image",
|
||||
finetune_keys=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.masked_image_key = masked_image_key
|
||||
assert self.masked_image_key in concat_keys
|
||||
self.concat_keys = concat_keys
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(
|
||||
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
|
||||
):
|
||||
# note: restricted to non-trainable encoders currently
|
||||
assert (
|
||||
not self.cond_stage_trainable
|
||||
), "trainable cond stages not yet supported for inpainting"
|
||||
z, c, x, xrec, xc = super().get_input(
|
||||
batch,
|
||||
self.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=True,
|
||||
return_original_cond=True,
|
||||
bs=bs,
|
||||
)
|
||||
|
||||
assert exists(self.concat_keys)
|
||||
c_cat = list()
|
||||
for ck in self.concat_keys:
|
||||
cc = (
|
||||
rearrange(batch[ck], "b h w c -> b c h w")
|
||||
.to(memory_format=torch.contiguous_format)
|
||||
.float()
|
||||
)
|
||||
if bs is not None:
|
||||
cc = cc[:bs]
|
||||
cc = cc.to(self.device)
|
||||
bchw = z.shape
|
||||
if ck != self.masked_image_key:
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
if return_first_stage_outputs:
|
||||
return z, all_conds, x, xrec, xc
|
||||
return z, all_conds
|
||||
|
@ -1,16 +1,16 @@
|
||||
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
||||
|
||||
import k_diffusion as K
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.util import rand_perlin_2d
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .sampler import Sampler
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
# at this threshold, the scheduler will stop using the Karras
|
||||
# noise schedule and start using the model's schedule
|
||||
STEP_THRESHOLD = 29
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
if threshold <= 0.0:
|
||||
@ -33,12 +33,21 @@ class CFGDenoiser(nn.Module):
|
||||
self.threshold = threshold
|
||||
self.warmup_max = warmup
|
||||
self.warmup = max(warmup / 10, 1)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
if self.warmup < self.warmup_max:
|
||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||
self.warmup += 1
|
||||
@ -46,8 +55,7 @@ class CFGDenoiser(nn.Module):
|
||||
thresh = self.threshold
|
||||
if thresh > self.threshold:
|
||||
thresh = self.threshold
|
||||
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
|
||||
|
||||
return cfg_apply_threshold(next_x, thresh)
|
||||
|
||||
class KSampler(Sampler):
|
||||
def __init__(self, model, schedule='lms', device=None, **kwargs):
|
||||
@ -60,16 +68,9 @@ class KSampler(Sampler):
|
||||
self.sigmas = None
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(
|
||||
x_in, sigma_in, cond=cond_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
|
||||
if self.karras_max is None:
|
||||
self.karras_max = STEP_THRESHOLD
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
@ -98,8 +99,13 @@ class KSampler(Sampler):
|
||||
rho=7.,
|
||||
device=self.device,
|
||||
)
|
||||
self.sigmas = self.model_sigmas
|
||||
#self.sigmas = self.karras_sigmas
|
||||
|
||||
if ddim_num_steps >= self.karras_max:
|
||||
print(f'>> Ksampler using model noise schedule (steps > {self.karras_max})')
|
||||
self.sigmas = self.model_sigmas
|
||||
else:
|
||||
print(f'>> Ksampler using karras noise schedule (steps <= {self.karras_max})')
|
||||
self.sigmas = self.karras_sigmas
|
||||
|
||||
# ALERT: We are completely overriding the sample() method in the base class, which
|
||||
# means that inpainting will not work. To get this to work we need to be able to
|
||||
@ -118,6 +124,7 @@ class KSampler(Sampler):
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
**kwargs
|
||||
):
|
||||
samples,_ = self.sample(
|
||||
batch_size = 1,
|
||||
@ -129,7 +136,8 @@ class KSampler(Sampler):
|
||||
unconditional_conditioning = unconditional_conditioning,
|
||||
img_callback = img_callback,
|
||||
x0 = init_latent,
|
||||
mask = mask
|
||||
mask = mask,
|
||||
**kwargs
|
||||
)
|
||||
return samples
|
||||
|
||||
@ -163,6 +171,7 @@ class KSampler(Sampler):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
@ -181,7 +190,6 @@ class KSampler(Sampler):
|
||||
)
|
||||
|
||||
# sigmas are set up in make_schedule - we take the last steps items
|
||||
total_steps = len(self.sigmas)
|
||||
sigmas = self.sigmas[-S-1:]
|
||||
|
||||
# x_T is variation noise. When an init image is provided (in x0) we need to add
|
||||
@ -195,19 +203,21 @@ class KSampler(Sampler):
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
}
|
||||
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||
return (
|
||||
sampling_result = (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
callback=route_callback
|
||||
),
|
||||
None,
|
||||
)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
# a workaround is found.
|
||||
@ -220,6 +230,7 @@ class KSampler(Sampler):
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
**kwargs,
|
||||
):
|
||||
if self.model_wrap is None:
|
||||
@ -245,6 +256,7 @@ class KSampler(Sampler):
|
||||
# so the actual formula for indexing into sigmas:
|
||||
# sigma_index = (steps-index)
|
||||
s_index = t_enc - index - 1
|
||||
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
|
||||
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||
self.model_wrap,
|
||||
img,
|
||||
@ -269,7 +281,7 @@ class KSampler(Sampler):
|
||||
else:
|
||||
return x
|
||||
|
||||
def prepare_to_sample(self,t_enc):
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
self.t_enc = t_enc
|
||||
self.model_wrap = None
|
||||
self.ds = None
|
||||
@ -281,3 +293,6 @@ class KSampler(Sampler):
|
||||
'''
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
|
||||
def conditioning_key(self)->str:
|
||||
return self.model.inner_model.model.conditioning_key
|
||||
|
||||
|
@ -5,6 +5,7 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
@ -13,6 +14,18 @@ class PLMSSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__(model,schedule,model.num_timesteps, device)
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
# this is the essential routine
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
@ -32,6 +45,7 @@ class PLMSSampler(Sampler):
|
||||
unconditional_conditioning=None,
|
||||
old_eps=[],
|
||||
t_next=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
@ -41,18 +55,15 @@ class PLMSSampler(Sampler):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(
|
||||
x_in, t_in, c_in
|
||||
).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index)
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
|
@ -2,13 +2,13 @@
|
||||
ldm.models.diffusion.sampler
|
||||
|
||||
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
|
||||
|
||||
'''
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
@ -24,6 +24,8 @@ class Sampler(object):
|
||||
self.ddpm_num_timesteps = steps
|
||||
self.schedule = schedule
|
||||
self.device = device or choose_torch_device()
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
@ -158,6 +160,18 @@ class Sampler(object):
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# check to see if make_schedule() has run, and if not, run it
|
||||
if self.ddim_timesteps is None:
|
||||
self.make_schedule(
|
||||
@ -190,10 +204,11 @@ class Sampler(object):
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
steps=S,
|
||||
**kwargs
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
#torch.no_grad()
|
||||
@torch.no_grad()
|
||||
def do_sampling(
|
||||
self,
|
||||
cond,
|
||||
@ -214,6 +229,7 @@ class Sampler(object):
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
steps=None,
|
||||
**kwargs
|
||||
):
|
||||
b = shape[0]
|
||||
time_range = (
|
||||
@ -231,7 +247,7 @@ class Sampler(object):
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
old_eps = []
|
||||
self.prepare_to_sample(t_enc=total_steps)
|
||||
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
|
||||
img = self.get_initial_image(x_T,shape,total_steps)
|
||||
|
||||
# probably don't need this at all
|
||||
@ -274,6 +290,7 @@ class Sampler(object):
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
step_count=steps
|
||||
)
|
||||
img, pred_x0, e_t = outs
|
||||
|
||||
@ -305,8 +322,9 @@ class Sampler(object):
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
all_timesteps_count = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
@ -321,7 +339,7 @@ class Sampler(object):
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
self.prepare_to_sample(t_enc=total_steps)
|
||||
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
@ -353,6 +371,7 @@ class Sampler(object):
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
t_next = ts_next,
|
||||
step_count=len(self.ddim_timesteps)
|
||||
)
|
||||
|
||||
x_dec, pred_x0, e_t = outs
|
||||
@ -411,3 +430,21 @@ class Sampler(object):
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
'''
|
||||
return self.model.q_sample(x0,ts)
|
||||
|
||||
def conditioning_key(self)->str:
|
||||
return self.model.model.conditioning_key
|
||||
|
||||
def uses_inpainting_model(self)->bool:
|
||||
return self.conditioning_key() in ('hybrid','concat')
|
||||
|
||||
def adjust_settings(self,**kwargs):
|
||||
'''
|
||||
This is a catch-all method for adjusting any instance variables
|
||||
after the sampler is instantiated. No type-checking performed
|
||||
here, so use with care!
|
||||
'''
|
||||
for k in kwargs.keys():
|
||||
try:
|
||||
setattr(self,k,kwargs[k])
|
||||
except AttributeError:
|
||||
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')
|
||||
|
216
ldm/models/diffusion/shared_invokeai_diffusion.py
Normal file
@ -0,0 +1,216 @@
|
||||
from math import ceil
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
'''
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
all InvokeAI diffusion procedures.
|
||||
|
||||
At the moment it includes the following features:
|
||||
* Cross attention control ("prompt2prompt")
|
||||
* Hybrid conditioning (used for inpainting)
|
||||
'''
|
||||
|
||||
|
||||
class ExtraConditioningInfo:
|
||||
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
|
||||
self.cross_attention_control_args = cross_attention_control_args
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
def __init__(self, model, model_forward_callback:
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
|
||||
|
||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
||||
self.conditioning = conditioning
|
||||
self.cross_attention_control_context = CrossAttentionControl.Context(
|
||||
arguments=self.conditioning.cross_attention_control_args,
|
||||
step_count=step_count
|
||||
)
|
||||
CrossAttentionControl.setup_cross_attention_control(self.model,
|
||||
cross_attention_control_args=self.conditioning.cross_attention_control_args
|
||||
)
|
||||
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
|
||||
#todo: apply edit_options using step_count
|
||||
|
||||
def remove_cross_attention_control(self):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
CrossAttentionControl.remove_cross_attention_control(self.model)
|
||||
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor,dict],
|
||||
conditioning: Union[torch.Tensor,dict],
|
||||
unconditional_guidance_scale: float,
|
||||
step_index: Optional[int]=None
|
||||
):
|
||||
"""
|
||||
:param x: current latents
|
||||
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
||||
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
||||
|
||||
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
||||
|
||||
if wants_hybrid_conditioning:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
|
||||
elif wants_cross_attention_control:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
|
||||
else:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
|
||||
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
|
||||
return combined_next_x
|
||||
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
# fast batched path
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
|
||||
both_conditionings).chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = dict()
|
||||
for k in conditioning:
|
||||
if isinstance(conditioning[k], list):
|
||||
both_conditionings[k] = [
|
||||
torch.cat([unconditioning[k][i], conditioning[k][i]])
|
||||
for i in range(len(conditioning[k]))
|
||||
]
|
||||
else:
|
||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_save_attention_maps(self.model, type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
return float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||
# find the best possible index of the current sigma in the sigma sequence
|
||||
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
|
||||
# flip because sigmas[0] is for the fully denoised image
|
||||
# percent_through must be <1
|
||||
return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
|
||||
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
|
||||
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
|
||||
|
||||
# below is fugly omg
|
||||
num_actual_conditionings = len(c_or_weighted_c_list)
|
||||
conditionings = [uc] + [c for c,weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c,weight in weighted_cond_list]
|
||||
chunk_count = ceil(len(conditionings)/2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index*2
|
||||
chunk_size = min(2, len(conditionings)-offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset:offset+2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
|
@ -1,5 +1,7 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module):
|
||||
return x+h_
|
||||
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
@ -170,46 +173,71 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_compvis(self, q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
self.attention_slice_wrangler = None
|
||||
|
||||
def einsum_op_slice_0(self, q, k, v, slice_size):
|
||||
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
self is the current CrossAttention module for which the callback is being invoked.
|
||||
attention_scores are the scores for attention
|
||||
suggested_attention_slice is a softmax(dim=-1) over attention_scores
|
||||
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If dim is >= 0, offset and slice_size specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
||||
# calculate attenion slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
if self.attention_slice_wrangler is not None:
|
||||
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
return einsum('b i j, b j d -> b i d', attention_slice, v)
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_1(self, q, k, v, slice_size):
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_1(q, k, v, slice_size)
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_0(q, k, v, 1)
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
@ -221,7 +249,7 @@ class CrossAttention(nn.Module):
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
def einsum_op(self, q, k, v):
|
||||
def get_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
@ -244,8 +272,13 @@ class CrossAttention(nn.Module):
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = self.einsum_op(q, k, v)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
r = self.get_attention_mem_efficient(q, k, v)
|
||||
|
||||
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(hidden_states)
|
||||
|
||||
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
@ -1,3 +1,5 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
@ -454,6 +456,223 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
||||
|
||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
fragment_weights_key = "fragment_weights"
|
||||
return_tokens_key = "return_tokens"
|
||||
|
||||
def forward(self, text: list, **kwargs):
|
||||
'''
|
||||
|
||||
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
||||
weights shall be applied.
|
||||
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
|
||||
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
|
||||
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
|
||||
'''
|
||||
if self.fragment_weights_key not in kwargs:
|
||||
# fallback to base class implementation
|
||||
return super().forward(text, **kwargs)
|
||||
|
||||
fragment_weights = kwargs[self.fragment_weights_key]
|
||||
# self.transformer doesn't like receiving "fragment_weights" as an argument
|
||||
kwargs.pop(self.fragment_weights_key)
|
||||
|
||||
should_return_tokens = False
|
||||
if self.return_tokens_key in kwargs:
|
||||
should_return_tokens = kwargs.get(self.return_tokens_key, False)
|
||||
# self.transformer doesn't like having extra kwargs
|
||||
kwargs.pop(self.return_tokens_key)
|
||||
|
||||
batch_z = None
|
||||
batch_tokens = None
|
||||
for fragments, weights in zip(text, fragment_weights):
|
||||
|
||||
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
|
||||
# applying a multiplier to the CFG scale on a per-token basis).
|
||||
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
|
||||
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
|
||||
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
|
||||
# "red" is to tell SD that it should almost completely *ignore* redness).
|
||||
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
|
||||
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
|
||||
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
|
||||
|
||||
# handle weights >=1
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
|
||||
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
# this is our starting point
|
||||
embeddings = base_embedding.unsqueeze(0)
|
||||
per_embedding_weights = [1.0]
|
||||
|
||||
# now handle weights <1
|
||||
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
|
||||
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
|
||||
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
|
||||
# removed.
|
||||
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
|
||||
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
|
||||
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
|
||||
for index, fragment_weight in enumerate(weights):
|
||||
if fragment_weight < 1:
|
||||
fragments_without_this = fragments[:index] + fragments[index+1:]
|
||||
weights_without_this = weights[:index] + weights[index+1:]
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
|
||||
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
|
||||
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
|
||||
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
|
||||
# therefore:
|
||||
# fragment_weight = 1: we are at base_z => lerp weight 0
|
||||
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
|
||||
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
|
||||
# so let's use tan(), because:
|
||||
# tan is 0.0 at 0,
|
||||
# 1.0 at PI/4, and
|
||||
# inf at PI/2
|
||||
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
|
||||
epsilon = 1e-9
|
||||
fragment_weight = max(epsilon, fragment_weight) # inf is bad
|
||||
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
|
||||
# todo handle negative weight?
|
||||
|
||||
per_embedding_weights.append(embedding_lerp_weight)
|
||||
|
||||
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
|
||||
|
||||
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||
|
||||
# append to batch
|
||||
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||
|
||||
# should have shape (B, 77, 768)
|
||||
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
|
||||
|
||||
if should_return_tokens:
|
||||
return batch_z, batch_tokens
|
||||
else:
|
||||
return batch_z
|
||||
|
||||
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||
tokens = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
if include_start_and_end_markers:
|
||||
return tokens
|
||||
else:
|
||||
return [x[1:-1] for x in tokens]
|
||||
|
||||
|
||||
@classmethod
|
||||
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
|
||||
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
|
||||
if normalize:
|
||||
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
|
||||
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
|
||||
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
|
||||
return torch.sum(embeddings * reshaped_weights, dim=1)
|
||||
# lerped embeddings has shape (77, 768)
|
||||
|
||||
|
||||
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
|
||||
'''
|
||||
|
||||
:param fragments:
|
||||
:param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
|
||||
:return:
|
||||
'''
|
||||
# empty is meaningful
|
||||
if len(fragments) == 0 and len(weights) == 0:
|
||||
fragments = ['']
|
||||
weights = [1]
|
||||
item_encodings = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=True,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
all_tokens = []
|
||||
per_token_weights = []
|
||||
#print("all fragments:", fragments, weights)
|
||||
for index, fragment in enumerate(item_encodings):
|
||||
weight = weights[index]
|
||||
#print("processing fragment", fragment, weight)
|
||||
fragment_tokens = item_encodings[index]
|
||||
#print("fragment", fragment, "processed to", fragment_tokens)
|
||||
# trim bos and eos markers before appending
|
||||
all_tokens.extend(fragment_tokens[1:-1])
|
||||
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
|
||||
|
||||
if (len(all_tokens) + 2) > self.max_length:
|
||||
excess_token_count = (len(all_tokens) + 2) - self.max_length
|
||||
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||
all_tokens = all_tokens[:self.max_length - 2]
|
||||
per_token_weights = per_token_weights[:self.max_length - 2]
|
||||
|
||||
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||
# (77 = self.max_length)
|
||||
pad_length = self.max_length - 1 - len(all_tokens)
|
||||
all_tokens.insert(0, self.tokenizer.bos_token_id)
|
||||
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
|
||||
per_token_weights.insert(0, 1)
|
||||
per_token_weights.extend([1] * pad_length)
|
||||
|
||||
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
|
||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
||||
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
|
||||
return all_tokens_tensor, per_token_weights_tensor
|
||||
|
||||
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||
'''
|
||||
Build a tensor representing the passed-in tokens, each of which has a weight.
|
||||
:param tokens: A tensor of shape (77) containing token ids (integers)
|
||||
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
||||
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
||||
:param kwargs: passed on to self.transformer()
|
||||
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
||||
'''
|
||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||
|
||||
if weight_delta_from_empty:
|
||||
empty_tokens = self.tokenizer([''] * z.shape[0],
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding='max_length',
|
||||
return_tensors='pt'
|
||||
)['input_ids'].to(self.device)
|
||||
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
|
||||
z_delta_from_empty = z - empty_z
|
||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||
|
||||
weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
||||
|
||||
#print("using empty-delta method, first 5 rows:")
|
||||
#print(weighted_z[:5])
|
||||
|
||||
return weighted_z
|
||||
|
||||
else:
|
||||
original_mean = z.mean()
|
||||
z *= batch_weights_expanded
|
||||
after_weighting_mean = z.mean()
|
||||
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
|
||||
mean_correction_factor = original_mean/after_weighting_mean
|
||||
z *= mean_correction_factor
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPTextEmbedder(nn.Module):
|
||||
"""
|
||||
|
@ -18,6 +18,7 @@ from ldm.invoke.image_util import make_grid
|
||||
from ldm.invoke.log import write_log
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from pyparsing import ParseException
|
||||
|
||||
# global used in multiple functions (fix)
|
||||
infile = None
|
||||
@ -172,8 +173,7 @@ def main_loop(gen, opt):
|
||||
pass
|
||||
|
||||
if len(opt.prompt) == 0:
|
||||
print('\nTry again with a prompt!')
|
||||
continue
|
||||
opt.prompt = ''
|
||||
|
||||
# width and height are set by model if not specified
|
||||
if not opt.width:
|
||||
@ -328,12 +328,16 @@ def main_loop(gen, opt):
|
||||
if operation == 'generate':
|
||||
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
|
||||
opt.last_operation='generate'
|
||||
gen.prompt2image(
|
||||
image_callback=image_writer,
|
||||
step_callback=step_callback,
|
||||
catch_interrupts=catch_ctrl_c,
|
||||
**vars(opt)
|
||||
)
|
||||
try:
|
||||
gen.prompt2image(
|
||||
image_callback=image_writer,
|
||||
step_callback=step_callback,
|
||||
catch_interrupts=catch_ctrl_c,
|
||||
**vars(opt)
|
||||
)
|
||||
except ParseException as e:
|
||||
print('** An error occurred while processing your prompt **')
|
||||
print(f'** {str(e)} **')
|
||||
elif operation == 'postprocess':
|
||||
print(f'>> fixing {opt.prompt}')
|
||||
opt.last_operation = do_postprocess(gen,opt,image_writer)
|
||||
@ -528,12 +532,8 @@ def del_config(model_name:str, gen, opt, completer):
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
yaml_str = gen.model_cache.del_model(model_name)
|
||||
|
||||
tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,opt.conf)
|
||||
if gen.model_cache.del_model(model_name):
|
||||
gen.model_cache.commit(opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.del_model(model_name)
|
||||
|
||||
@ -592,7 +592,9 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
|
||||
|
||||
def do_textmask(gen, opt, callback):
|
||||
image_path = opt.prompt
|
||||
assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
|
||||
if not os.path.exists(image_path):
|
||||
image_path = os.path.join(opt.outdir,image_path)
|
||||
assert os.path.exists(image_path), '** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **'
|
||||
assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **'
|
||||
tm = opt.text_mask[0]
|
||||
threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5
|
||||
|
440
tests/test_prompt_parser.py
Normal file
@ -0,0 +1,440 @@
|
||||
import unittest
|
||||
|
||||
import pyparsing
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
|
||||
Fragment
|
||||
|
||||
|
||||
def parse_prompt(prompt_string):
|
||||
pp = PromptParser()
|
||||
#print(f"parsing '{prompt_string}'")
|
||||
parse_result = pp.parse_conjunction(prompt_string)
|
||||
#print(f"-> parsed '{prompt_string}' to {parse_result}")
|
||||
return parse_result
|
||||
|
||||
def make_basic_conjunction(strings: list[str]):
|
||||
fragments = [Fragment(x) for x in strings]
|
||||
return Conjunction([FlattenedPrompt(fragments)])
|
||||
|
||||
def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]):
|
||||
fragments = [Fragment(x, w) for x,w in weighted_strings]
|
||||
return Conjunction([FlattenedPrompt(fragments)])
|
||||
|
||||
|
||||
class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
def test_empty(self):
|
||||
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
|
||||
|
||||
def test_basic(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
||||
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
|
||||
self.assertEqual(make_basic_conjunction(['Dalí']), parse_prompt("Dalí"))
|
||||
|
||||
def test_attention(self):
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
|
||||
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
|
||||
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]),
|
||||
parse_prompt("(pretty flowers)+"))
|
||||
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]),
|
||||
parse_prompt("(pretty flowers)+, the flames are too hot"))
|
||||
|
||||
def test_no_parens_attention_runon(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--"))
|
||||
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames"))
|
||||
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames"))
|
||||
|
||||
|
||||
def test_explicit_conjunction(self):
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
|
||||
self.assertEqual(
|
||||
Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
|
||||
FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
|
||||
|
||||
def test_conjunction_weights(self):
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))
|
||||
|
||||
with self.assertRaises(PromptParser.ParsingException):
|
||||
parse_prompt('("fire", "flames").and(2)')
|
||||
parse_prompt('("fire", "flames").and(2,1,2)')
|
||||
|
||||
def test_complex_conjunction(self):
|
||||
|
||||
#print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++"))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
|
||||
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)"))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
|
||||
FlattenedPrompt([("a person with a hat", 1.0),
|
||||
("riding a", 1.1*1.1),
|
||||
CrossAttentionControlSubstitute(
|
||||
[Fragment("bicycle", pow(1.1,2))],
|
||||
[Fragment("skateboard", pow(1.1,2))])
|
||||
])
|
||||
], weights=[0.5, 0.5]),
|
||||
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
|
||||
|
||||
def test_badly_formed(self):
|
||||
def make_untouched_prompt(prompt):
|
||||
return Conjunction([FlattenedPrompt([(prompt, 1.0)])])
|
||||
|
||||
def assert_if_prompt_string_not_untouched(prompt):
|
||||
self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
|
||||
|
||||
assert_if_prompt_string_not_untouched('a test prompt')
|
||||
assert_if_prompt_string_not_untouched('a badly formed +test prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed test prompt')
|
||||
#with self.assertRaises(pyparsing.ParseException):
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed +test prompt')
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(((a badly (formed +test )prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(a (ba)dly (f)ormed +test prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('("((a badly (formed +test ").blend(1.0)')
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
||||
parse_prompt("hamburger ((bun))"))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
||||
parse_prompt("hamburger (bun)"))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
|
||||
parse_prompt("hamburger (kaiser roll)"))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
|
||||
parse_prompt("hamburger ((kaiser roll))"))
|
||||
|
||||
|
||||
def test_blend(self):
|
||||
self.assertEqual(Conjunction(
|
||||
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
|
||||
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
|
||||
)
|
||||
self.assertEqual(Conjunction([Blend(
|
||||
[FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
|
||||
[0.7, 0.3, 1.0])]),
|
||||
parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
|
||||
)
|
||||
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
|
||||
FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
|
||||
FlattenedPrompt([('hi', 1.0)])],
|
||||
weights=[0.7, 0.3, 1.0])]),
|
||||
parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)")
|
||||
)
|
||||
# blend a single entry is not a failure
|
||||
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
|
||||
parse_prompt("(\"fire\").blend(0.7)")
|
||||
)
|
||||
# blend with empty
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
|
||||
)
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
|
||||
)
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
|
||||
)
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
|
||||
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
|
||||
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
|
||||
)
|
||||
|
||||
|
||||
def test_nested(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
|
||||
parse_prompt('fire (flames (trees)1.5)2.0'))
|
||||
self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
|
||||
FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
|
||||
weights=[1.0, 1.0])]),
|
||||
parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)'))
|
||||
|
||||
def test_cross_attention_control(self):
|
||||
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
|
||||
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
|
||||
|
||||
|
||||
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
|
||||
|
||||
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
|
||||
|
||||
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
|
||||
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
|
||||
(', fire', 1.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
|
||||
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape "".swap("in winter")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape " ".swap("in winter")'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap("")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap(" ")'))
|
||||
|
||||
def test_cross_attention_control_with_attention(self):
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||
Fragment('eating a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
|
||||
])]),
|
||||
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)"))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||
Fragment('eating a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
|
||||
])]),
|
||||
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)"))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||
Fragment('eating a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
|
||||
])]),
|
||||
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)"))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||
Fragment('eating a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
|
||||
])]),
|
||||
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)"))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||
Fragment('eating a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(0.9,1))])
|
||||
])]),
|
||||
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)"))
|
||||
|
||||
|
||||
def test_cross_attention_control_options(self):
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}),
|
||||
Fragment('eating a hotdog', 1)])]),
|
||||
parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog"))
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}),
|
||||
Fragment('eating a hotdog', 1)])]),
|
||||
parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog"))
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}),
|
||||
Fragment('eating a hotdog', 1)])]),
|
||||
parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog"))
|
||||
|
||||
self.assertEqual(
|
||||
Conjunction([
|
||||
FlattenedPrompt([Fragment('a fantasy forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)],
|
||||
options={'s_start': 0.8, 't_start': 0.8})])]),
|
||||
parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)"))
|
||||
|
||||
|
||||
def test_escaping(self):
|
||||
|
||||
# make sure ", ( and ) can be escaped
|
||||
|
||||
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)'))
|
||||
self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)'))
|
||||
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+'))
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+'))
|
||||
# same weights for each are combined into one
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+'))
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1'))
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1'))
|
||||
# same weights for each are combined into one
|
||||
self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
|
||||
self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
|
||||
|
||||
def test_cross_attention_escaping(self):
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain (man).swap(monkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (man).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
|
||||
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain ("man").swap(monkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain ("man").swap("monkey")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain (\\"man).swap("monkey")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (man).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
|
||||
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
|
||||
|
||||
def test_legacy_blend(self):
|
||||
pp = PromptParser()
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain man:1 man mountain:1'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
|
||||
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
|
||||
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain+ man:1 man mountain-'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
|
||||
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain+ man: man mountain-:'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[0.75,0.25]),
|
||||
pp.parse_legacy_blend('mountain man:3 man mountain:1'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[1.0,0.0]),
|
||||
pp.parse_legacy_blend('mountain man:3 man mountain:0'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[0.8,0.2]),
|
||||
pp.parse_legacy_blend('"mountain man":4 man mountain'))
|
||||
|
||||
|
||||
def test_single(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
1
tests/validate_pr_prompt.txt
Normal file
@ -0,0 +1 @@
|
||||
banana sushi -Ak_lms -S42 -s10
|