diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000..8a0ebc5069 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +* +!environment*.yml +!docker-build diff --git a/.github/workflows/build-container.yml b/.github/workflows/build-container.yml new file mode 100644 index 0000000000..036ad9679d --- /dev/null +++ b/.github/workflows/build-container.yml @@ -0,0 +1,43 @@ +# Building the Image without pushing to confirm it is still buildable +# confirum functionality would unfortunately need way more resources +name: build container image +on: + push: + branches: + - 'main' + - 'development' + pull_request: + branches: + - 'main' + - 'development' + +jobs: + docker: + runs-on: ubuntu-latest + steps: + - name: prepare docker-tag + env: + repository: ${{ github.repository }} + run: echo "dockertag=${repository,,}" >> $GITHUB_ENV + - name: Checkout + uses: actions/checkout@v3 + - name: Set up QEMU + uses: docker/setup-qemu-action@v2 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Cache Docker layers + uses: actions/cache@v2 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-buildx-${{ github.sha }} + restore-keys: ${{ runner.os }}-buildx- + - name: Build container + uses: docker/build-push-action@v3 + with: + context: . + file: docker-build/Dockerfile + platforms: linux/amd64 + push: false + tags: ${{ env.dockertag }}:latest + cache-from: type=local,src=/tmp/.buildx-cache + cache-to: type=local,dest=/tmp/.buildx-cache diff --git a/.github/workflows/create-caches.yml b/.github/workflows/create-caches.yml index bbd95f58d8..33ea3b82ed 100644 --- a/.github/workflows/create-caches.yml +++ b/.github/workflows/create-caches.yml @@ -54,27 +54,9 @@ jobs: [[ -d models/ldm/stable-diffusion-v1 ]] \ || mkdir -p models/ldm/stable-diffusion-v1 [[ -r models/ldm/stable-diffusion-v1/model.ckpt ]] \ - || curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }} - - - name: Use cached Conda Environment - uses: actions/cache@v3 - env: - cache-name: cache-conda-env-${{ env.CONDA_ENV_NAME }} - conda-env-file: ${{ matrix.environment-file }} - with: - path: ${{ env.CONDA_ROOT }}/envs/${{ env.CONDA_ENV_NAME }} - key: ${{ env.cache-name }} - restore-keys: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }} - - - name: Use cached Conda Packages - uses: actions/cache@v3 - env: - cache-name: cache-conda-env-${{ env.CONDA_ENV_NAME }} - conda-env-file: ${{ matrix.environment-file }} - with: - path: ${{ env.CONDA_PKGS_DIR }} - key: ${{ env.cache-name }} - restore-keys: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }} + || curl --user "${{ secrets.HUGGINGFACE_TOKEN }}" \ + -o models/ldm/stable-diffusion-v1/model.ckpt \ + -O -L https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt - name: Activate Conda Env uses: conda-incubator/setup-miniconda@v2 diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml index e2643facfa..65e3e4e1e2 100644 --- a/.github/workflows/test-invoke-conda.yml +++ b/.github/workflows/test-invoke-conda.yml @@ -4,8 +4,7 @@ on: branches: - 'main' - 'development' - - 'fix-gh-actions-fork' - pull_request: + pull_request_target: branches: - 'main' - 'development' @@ -13,15 +12,18 @@ on: jobs: os_matrix: strategy: + fail-fast: false matrix: os: [ubuntu-latest, macos-latest] include: - os: ubuntu-latest environment-file: environment.yml default-shell: bash -l {0} + stable-diffusion-model: 
https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt - os: macos-latest environment-file: environment-mac.yml default-shell: bash -l {0} + stable-diffusion-model: https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt name: Test invoke.py on ${{ matrix.os }} with conda runs-on: ${{ matrix.os }} defaults: @@ -48,7 +50,7 @@ jobs: - name: set test prompt to Pull Request validation if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }} - run: echo "TEST_PROMPTS=tests/pr_prompt.txt" >> $GITHUB_ENV + run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> $GITHUB_ENV - name: set conda environment name run: echo "CONDA_ENV_NAME=invokeai" >> $GITHUB_ENV @@ -57,7 +59,7 @@ jobs: id: cache-sd-v1-4 uses: actions/cache@v3 env: - cache-name: cache-sd-v1-4 + cache-name: cache-sd-${{ matrix.stable-diffusion-model }} with: path: models/ldm/stable-diffusion-v1/model.ckpt key: ${{ env.cache-name }} @@ -69,25 +71,9 @@ jobs: [[ -d models/ldm/stable-diffusion-v1 ]] \ || mkdir -p models/ldm/stable-diffusion-v1 [[ -r models/ldm/stable-diffusion-v1/model.ckpt ]] \ - || curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }} - - - name: Use cached Conda Environment - uses: actions/cache@v3 - env: - cache-name: cache-conda-env-${{ env.CONDA_ENV_NAME }} - conda-env-file: ${{ matrix.environment-file }} - with: - path: ${{ env.CONDA }}/envs/${{ env.CONDA_ENV_NAME }} - key: env-${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }} - - - name: Use cached Conda Packages - uses: actions/cache@v3 - env: - cache-name: cache-conda-pkgs-${{ env.CONDA_ENV_NAME }} - conda-env-file: ${{ matrix.environment-file }} - with: - path: ${{ env.CONDA_PKGS_DIR }} - key: pkgs-${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(env.conda-env-file) }} + || curl --user "${{ secrets.HUGGINGFACE_TOKEN }}" \ + -o models/ldm/stable-diffusion-v1/model.ckpt \ + -O -L ${{ matrix.stable-diffusion-model }} - name: Activate Conda Env uses: conda-incubator/setup-miniconda@v2 diff --git a/.gitignore b/.gitignore index c36ba7e4a0..ecef2713bc 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,10 @@ outputs/ models/ldm/stable-diffusion-v1/model.ckpt ldm/invoke/restoration/codeformer/weights +# ignore user models config +configs/models.user.yaml +config/models.user.yml + # ignore the Anaconda/Miniconda installer used while building Docker image anaconda.sh diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index 96ecda1af1..dabe072f80 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -14,7 +14,7 @@ from threading import Event from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash from ldm.invoke.pngwriter import PngWriter, retrieve_metadata -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts from backend.modules.parameters import parameters_to_command diff --git a/backend/server.py b/backend/server.py index 7b8a8a5a69..8ad861356c 100644 --- a/backend/server.py +++ b/backend/server.py @@ -33,7 +33,7 @@ from ldm.generate import Generate from ldm.invoke.restoration import Restoration from ldm.invoke.pngwriter import PngWriter, retrieve_metadata from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts from 
modules.parameters import parameters_to_command diff --git a/configs/models.yaml b/configs/models.yaml index 332ee26409..162da38da2 100644 --- a/configs/models.yaml +++ b/configs/models.yaml @@ -13,6 +13,13 @@ stable-diffusion-1.4: width: 512 height: 512 default: true +inpainting-1.5: + description: runwayML tuned inpainting model v1.5 + weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt + config: configs/stable-diffusion/v1-inpainting-inference.yaml +# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt + width: 512 + height: 512 stable-diffusion-1.5: config: configs/stable-diffusion/v1-inference.yaml weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml index 9c773077b6..baf91f6e26 100644 --- a/configs/stable-diffusion/v1-inference.yaml +++ b/configs/stable-diffusion/v1-inference.yaml @@ -76,4 +76,4 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder diff --git a/configs/stable-diffusion/v1-inpainting-inference.yaml b/configs/stable-diffusion/v1-inpainting-inference.yaml new file mode 100644 index 0000000000..3ea164a359 --- /dev/null +++ b/configs/stable-diffusion/v1-inpainting-inference.yaml @@ -0,0 +1,79 @@ +model: + base_learning_rate: 7.5e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important + monitor: val/loss_simple_ema + scale_factor: 0.18215 + finetune_keys: null + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + personalization_config: + target: ldm.modules.embedding_manager.EmbeddingManager + params: + placeholder_strings: ["*"] + initializer_words: ['face', 'man', 'photo', 'africanmale'] + per_image_tokens: false + num_vectors_per_token: 1 + progressive_words: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 # 4 data + 4 downscaled image + 1 mask + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder diff --git a/docker-build/Dockerfile b/docker-build/Dockerfile index f3d6834c93..2f0c892730 100644 --- a/docker-build/Dockerfile +++ b/docker-build/Dockerfile @@ -1,57 +1,74 @@ -FROM debian +FROM ubuntu AS get_miniconda -ARG gsd -ENV GITHUB_STABLE_DIFFUSION $gsd +SHELL ["/bin/bash", "-c"] -ARG rsd -ENV REQS $rsd +# install wget +RUN apt-get update \ + && apt-get install -y \ + wget \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* -ARG cs -ENV CONDA_SUBDIR $cs +# download and install miniconda +ARG conda_version=py39_4.12.0-Linux-x86_64 +ARG conda_prefix=/opt/conda +RUN wget --progress=dot:giga -O /miniconda.sh \ + https://repo.anaconda.com/miniconda/Miniconda3-${conda_version}.sh \ + && bash /miniconda.sh -b -p ${conda_prefix} \ + && rm -f /miniconda.sh -ENV PIP_EXISTS_ACTION="w" +FROM ubuntu AS invokeai -# TODO: Optimize image size +# use bash +SHELL [ "/bin/bash", "-c" ] -SHELL ["/bin/bash", "-c"] +# clean bashrc +RUN echo "" > ~/.bashrc -WORKDIR / -RUN apt update && apt upgrade -y \ - && apt install -y \ - git \ - libgl1-mesa-glx \ - libglib2.0-0 \ - pip \ - python3 \ - && git clone $GITHUB_STABLE_DIFFUSION +# Install necesarry packages +RUN apt-get update \ + && apt-get install -y \ + --no-install-recommends \ + gcc \ + git \ + libgl1-mesa-glx \ + libglib2.0-0 \ + pip \ + python3 \ + python3-dev \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* -# Install Anaconda or Miniconda -COPY anaconda.sh . 
-RUN bash anaconda.sh -b -u -p /anaconda && /anaconda/bin/conda init bash +# clone repository and create symlinks +ARG invokeai_git=https://github.com/invoke-ai/InvokeAI.git +ARG project_name=invokeai +RUN git clone ${invokeai_git} /${project_name} \ + && mkdir /${project_name}/models/ldm/stable-diffusion-v1 \ + && ln -s /data/models/sd-v1-4.ckpt /${project_name}/models/ldm/stable-diffusion-v1/model.ckpt \ + && ln -s /data/outputs/ /${project_name}/outputs -# SD -WORKDIR /stable-diffusion -RUN source ~/.bashrc \ - && conda create -y --name ldm && conda activate ldm \ - && conda config --env --set subdir $CONDA_SUBDIR \ - && pip3 install -r $REQS \ - && pip3 install basicsr facexlib realesrgan \ - && mkdir models/ldm/stable-diffusion-v1 \ - && ln -s "/data/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt +# set workdir +WORKDIR /${project_name} -# Face restoreation -# by default expected in a sibling directory to stable-diffusion -WORKDIR / -RUN git clone https://github.com/TencentARC/GFPGAN.git +# install conda env and preload models +ARG conda_prefix=/opt/conda +ARG conda_env_file=environment.yml +COPY --from=get_miniconda ${conda_prefix} ${conda_prefix} +RUN source ${conda_prefix}/etc/profile.d/conda.sh \ + && conda init bash \ + && source ~/.bashrc \ + && conda env create \ + --name ${project_name} \ + --file ${conda_env_file} \ + && rm -Rf ~/.cache \ + && conda clean -afy \ + && echo "conda activate ${project_name}" >> ~/.bashrc \ + && ln -s /data/models/GFPGANv1.4.pth ./src/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth \ + && conda activate ${project_name} \ + && python scripts/preload_models.py -WORKDIR /GFPGAN -RUN pip3 install -r requirements.txt \ - && python3 setup.py develop \ - && ln -s "/data/GFPGANv1.4.pth" experiments/pretrained_models/GFPGANv1.4.pth - -WORKDIR /stable-diffusion -RUN python3 scripts/preload_models.py - -WORKDIR / -COPY entrypoint.sh . -ENTRYPOINT ["/entrypoint.sh"] \ No newline at end of file +# Copy entrypoint and set env +ENV CONDA_PREFIX=${conda_prefix} +ENV PROJECT_NAME=${project_name} +COPY docker-build/entrypoint.sh / +ENTRYPOINT [ "/entrypoint.sh" ] diff --git a/docker-build/build.sh b/docker-build/build.sh new file mode 100755 index 0000000000..fea311ae82 --- /dev/null +++ b/docker-build/build.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash +set -e +# IMPORTANT: You need to have a token on huggingface.co to be able to download the checkpoint!!! +# configure values by using env when executing build.sh +# f.e. 
env ARCH=aarch64 GITHUB_INVOKE_AI=https://github.com/yourname/yourfork.git ./build.sh + +source ./docker-build/env.sh || echo "please run from repository root" || exit 1 + +invokeai_conda_version=${INVOKEAI_CONDA_VERSION:-py39_4.12.0-${platform/\//-}} +invokeai_conda_prefix=${INVOKEAI_CONDA_PREFIX:-\/opt\/conda} +invokeai_conda_env_file=${INVOKEAI_CONDA_ENV_FILE:-environment.yml} +invokeai_git=${INVOKEAI_GIT:-https://github.com/invoke-ai/InvokeAI.git} +huggingface_token=${HUGGINGFACE_TOKEN?} + +# print the settings +echo "You are using these values:" +echo -e "project_name:\t\t ${project_name}" +echo -e "volumename:\t\t ${volumename}" +echo -e "arch:\t\t\t ${arch}" +echo -e "platform:\t\t ${platform}" +echo -e "invokeai_conda_version:\t ${invokeai_conda_version}" +echo -e "invokeai_conda_prefix:\t ${invokeai_conda_prefix}" +echo -e "invokeai_conda_env_file: ${invokeai_conda_env_file}" +echo -e "invokeai_git:\t\t ${invokeai_git}" +echo -e "invokeai_tag:\t\t ${invokeai_tag}\n" + +_runAlpine() { + docker run \ + --rm \ + --interactive \ + --tty \ + --mount source="$volumename",target=/data \ + --workdir /data \ + alpine "$@" +} + +_copyCheckpoints() { + echo "creating subfolders for models and outputs" + _runAlpine mkdir models + _runAlpine mkdir outputs + echo -n "downloading sd-v1-4.ckpt" + _runAlpine wget --header="Authorization: Bearer ${huggingface_token}" -O models/sd-v1-4.ckpt https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt + echo "done" + echo "downloading GFPGANv1.4.pth" + _runAlpine wget -O models/GFPGANv1.4.pth https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth +} + +_checkVolumeContent() { + _runAlpine ls -lhA /data/models +} + +_getModelMd5s() { + _runAlpine \ + alpine sh -c "md5sum /data/models/*" +} + +if [[ -n "$(docker volume ls -f name="${volumename}" -q)" ]]; then + echo "Volume already exists" + if [[ -z "$(_checkVolumeContent)" ]]; then + echo "looks empty, copying checkpoint" + _copyCheckpoints + fi + echo "Models in ${volumename}:" + _checkVolumeContent +else + echo -n "createing docker volume " + docker volume create "${volumename}" + _copyCheckpoints +fi + +# Build Container +docker build \ + --platform="${platform}" \ + --tag "${invokeai_tag}" \ + --build-arg project_name="${project_name}" \ + --build-arg conda_version="${invokeai_conda_version}" \ + --build-arg conda_prefix="${invokeai_conda_prefix}" \ + --build-arg conda_env_file="${invokeai_conda_env_file}" \ + --build-arg invokeai_git="${invokeai_git}" \ + --file ./docker-build/Dockerfile \ + . 
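The build script seeds a named Docker volume with the checkpoints before building the image, so it can be useful to inspect that volume out of band. Below is a small sketch that reuses the same Alpine-container trick the script itself relies on; the volume name shown is the default `invokeai_data` (override it with `VOLUMENAME` if you changed it):

```bash
# confirm the volume exists and list the checkpoints build.sh copied into it
docker volume ls -f name=invokeai_data -q
docker run --rm --mount source=invokeai_data,target=/data alpine ls -lhA /data/models

# optionally checksum the downloaded weights before the first run
docker run --rm --mount source=invokeai_data,target=/data alpine sh -c "md5sum /data/models/*"
```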
diff --git a/docker-build/entrypoint.sh b/docker-build/entrypoint.sh index f47e6669e0..7c0ca12f88 100755 --- a/docker-build/entrypoint.sh +++ b/docker-build/entrypoint.sh @@ -1,10 +1,8 @@ #!/bin/bash +set -e -cd /stable-diffusion +source "${CONDA_PREFIX}/etc/profile.d/conda.sh" +conda activate "${PROJECT_NAME}" -if [ $# -eq 0 ]; then - python3 scripts/dream.py --full_precision -o /data - # bash -else - python3 scripts/dream.py --full_precision -o /data "$@" -fi \ No newline at end of file +python scripts/invoke.py \ + ${@:---web --host=0.0.0.0} diff --git a/docker-build/env.sh b/docker-build/env.sh new file mode 100644 index 0000000000..36b718d362 --- /dev/null +++ b/docker-build/env.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +project_name=${PROJECT_NAME:-invokeai} +volumename=${VOLUMENAME:-${project_name}_data} +arch=${ARCH:-x86_64} +platform=${PLATFORM:-Linux/${arch}} +invokeai_tag=${INVOKEAI_TAG:-${project_name}-${arch}} + +export project_name +export volumename +export arch +export platform +export invokeai_tag diff --git a/docker-build/run.sh b/docker-build/run.sh new file mode 100755 index 0000000000..3d1a564f4c --- /dev/null +++ b/docker-build/run.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +set -e + +source ./docker-build/env.sh || echo "please run from repository root" || exit 1 + +docker run \ + --interactive \ + --tty \ + --rm \ + --platform "$platform" \ + --name "$project_name" \ + --hostname "$project_name" \ + --mount source="$volumename",target=/data \ + --publish 9090:9090 \ + "$invokeai_tag" ${1:+$@} diff --git a/docs/assets/prompt_syntax/apricots--1.png b/docs/assets/prompt_syntax/apricots--1.png new file mode 100644 index 0000000000..0f0c17f08b Binary files /dev/null and b/docs/assets/prompt_syntax/apricots--1.png differ diff --git a/docs/assets/prompt_syntax/apricots--2.png b/docs/assets/prompt_syntax/apricots--2.png new file mode 100644 index 0000000000..5c519b09ae Binary files /dev/null and b/docs/assets/prompt_syntax/apricots--2.png differ diff --git a/docs/assets/prompt_syntax/apricots--3.png b/docs/assets/prompt_syntax/apricots--3.png new file mode 100644 index 0000000000..c98ffd8b07 Binary files /dev/null and b/docs/assets/prompt_syntax/apricots--3.png differ diff --git a/docs/assets/prompt_syntax/apricots-0.png b/docs/assets/prompt_syntax/apricots-0.png new file mode 100644 index 0000000000..f8ead74db2 Binary files /dev/null and b/docs/assets/prompt_syntax/apricots-0.png differ diff --git a/docs/assets/prompt_syntax/apricots-1.png b/docs/assets/prompt_syntax/apricots-1.png new file mode 100644 index 0000000000..75ff7a24a3 Binary files /dev/null and b/docs/assets/prompt_syntax/apricots-1.png differ diff --git a/docs/assets/prompt_syntax/apricots-2.png b/docs/assets/prompt_syntax/apricots-2.png new file mode 100644 index 0000000000..e24ced7637 Binary files /dev/null and b/docs/assets/prompt_syntax/apricots-2.png differ diff --git a/docs/assets/prompt_syntax/apricots-3.png b/docs/assets/prompt_syntax/apricots-3.png new file mode 100644 index 0000000000..d6edf5073c Binary files /dev/null and b/docs/assets/prompt_syntax/apricots-3.png differ diff --git a/docs/assets/prompt_syntax/apricots-4.png b/docs/assets/prompt_syntax/apricots-4.png new file mode 100644 index 0000000000..291c5a1b03 Binary files /dev/null and b/docs/assets/prompt_syntax/apricots-4.png differ diff --git a/docs/assets/prompt_syntax/apricots-5.png b/docs/assets/prompt_syntax/apricots-5.png new file mode 100644 index 0000000000..c71b857837 Binary files /dev/null and 
b/docs/assets/prompt_syntax/apricots-5.png differ diff --git a/docs/assets/prompt_syntax/mountain-man.png b/docs/assets/prompt_syntax/mountain-man.png new file mode 100644 index 0000000000..4bfa5817b8 Binary files /dev/null and b/docs/assets/prompt_syntax/mountain-man.png differ diff --git a/docs/assets/prompt_syntax/mountain-man1.png b/docs/assets/prompt_syntax/mountain-man1.png new file mode 100644 index 0000000000..5ed98162d3 Binary files /dev/null and b/docs/assets/prompt_syntax/mountain-man1.png differ diff --git a/docs/assets/prompt_syntax/mountain-man2.png b/docs/assets/prompt_syntax/mountain-man2.png new file mode 100644 index 0000000000..d4466514de Binary files /dev/null and b/docs/assets/prompt_syntax/mountain-man2.png differ diff --git a/docs/assets/prompt_syntax/mountain-man3.png b/docs/assets/prompt_syntax/mountain-man3.png new file mode 100644 index 0000000000..3196c5ca96 Binary files /dev/null and b/docs/assets/prompt_syntax/mountain-man3.png differ diff --git a/docs/assets/prompt_syntax/mountain-man4.png b/docs/assets/prompt_syntax/mountain-man4.png new file mode 100644 index 0000000000..69522dba23 Binary files /dev/null and b/docs/assets/prompt_syntax/mountain-man4.png differ diff --git a/docs/assets/prompt_syntax/mountain1-man.png b/docs/assets/prompt_syntax/mountain1-man.png new file mode 100644 index 0000000000..b5952d02f9 Binary files /dev/null and b/docs/assets/prompt_syntax/mountain1-man.png differ diff --git a/docs/assets/prompt_syntax/mountain2-man.png b/docs/assets/prompt_syntax/mountain2-man.png new file mode 100644 index 0000000000..8ab55ff2a7 Binary files /dev/null and b/docs/assets/prompt_syntax/mountain2-man.png differ diff --git a/docs/assets/prompt_syntax/mountain3-man.png b/docs/assets/prompt_syntax/mountain3-man.png new file mode 100644 index 0000000000..c1024b0a63 Binary files /dev/null and b/docs/assets/prompt_syntax/mountain3-man.png differ diff --git a/docs/features/CLI.md b/docs/features/CLI.md index 67a187fb3b..8e7b5780f0 100644 --- a/docs/features/CLI.md +++ b/docs/features/CLI.md @@ -153,6 +153,7 @@ Here are the invoke> command that apply to txt2img: | --cfg_scale | -C | 7.5 | How hard to try to match the prompt to the generated image; any number greater than 1.0 works, but the useful range is roughly 5.0 to 20.0 | | --seed | -S | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.| | --sampler | -A| k_lms | Sampler to use. Use -h to get list of available samplers. | +| --karras_max | | 29 | When using k_* samplers, set the maximum number of steps before shifting from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts) This value is sticky. [29] | | --hires_fix | | | Larger images often have duplication artefacts. 
This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution | | --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) | | --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt | @@ -218,8 +219,13 @@ well as the --mask (-M) and --text_mask (-tm) arguments: | Argument | Shortcut | Default | Description | |--------------------|------------|---------------------|--------------| | `--init_mask ` | `-M` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.| +| `--invert_mask ` | | False |If true, invert the mask so that transparent areas are opaque and vice versa.| | `--text_mask []` | `-tm []` | | Create a mask from a text prompt describing part of the image| +The mask may either be an image with transparent areas, in which case +the inpainting will occur in the transparent areas only, or a black +and white image, in which case all black areas will be painted into. + `--text_mask` (short form `-tm`) is a way to generate a mask using a text description of the part of the image to replace. For example, if you have an image of a breakfast plate with a bagel, toast and diff --git a/docs/features/IMG2IMG.md b/docs/features/IMG2IMG.md index a540ff5cc9..99c72d2a73 100644 --- a/docs/features/IMG2IMG.md +++ b/docs/features/IMG2IMG.md @@ -121,8 +121,6 @@ Both of the outputs look kind of like what I was thinking of. With the strength If you want to try this out yourself, all of these are using a seed of `1592514025` with a width/height of `384`, step count `10`, the default sampler (`k_lms`), and the single-word prompt `"fire"`: -If you want to try this out yourself, all of these are using a seed of `1592514025` with a width/height of `384`, step count `10`, the default sampler (`k_lms`), and the single-word prompt `fire`: - ```commandline invoke> "fire" -s10 -W384 -H384 -S1592514025 -I /tmp/fire-drawing.png --strength 0.7 ``` diff --git a/docs/features/INPAINTING.md b/docs/features/INPAINTING.md index 0503f1fc88..0308243b2f 100644 --- a/docs/features/INPAINTING.md +++ b/docs/features/INPAINTING.md @@ -34,6 +34,16 @@ original unedited image and the masked (partially transparent) image: invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png ``` +If you are using Photoshop to make your transparent masks, here is a +protocol contributed by III_Communication36 (Discord name): + + Create your alpha channel for mask in photoshop, then run + image/adjust/threshold on that channel. Export as Save a copy using + superpng (3rd party free download plugin) making sure alpha channel + is selected. Then masking works as it should for the img2img + process 100%. Can feed just one image this way without needing to + feed the -M mask behind it + ## **Masking using Text** You can also create a mask using a text prompt to select the part of @@ -139,7 +149,83 @@ region directly: invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20 ``` -### Inpainting is not changing the masked region enough! 
+## Using the RunwayML inpainting model + +The [RunwayML Inpainting Model +v1.5](https://huggingface.co/runwayml/stable-diffusion-inpainting) is +a specialized version of [Stable Diffusion +v1.5](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5) +that contains extra channels specifically designed to enhance +inpainting and outpainting. While it can do regular `txt2img` and +`img2img`, it really shines when filling in missing regions. It has an +almost uncanny ability to blend the new regions with existing ones in +a semantically coherent way. + +To install the inpainting model, follow the +[instructions](INSTALLING-MODELS.md) for installing a new model. You +may use either the CLI (`invoke.py` script) or directly edit the +`configs/models.yaml` configuration file to do this. The main thing to +watch out for is that the model `config` option must be set up to +use `v1-inpainting-inference.yaml` rather than the `v1-inference.yaml` +file that is used by Stable Diffusion 1.4 and 1.5. + +After installation, your `models.yaml` should contain an entry that +looks like this one: + + inpainting-1.5: + weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt + description: SD inpainting v1.5 + config: configs/stable-diffusion/v1-inpainting-inference.yaml + vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt + width: 512 + height: 512 + +As shown in the example, you may include a VAE fine-tuning weights +file as well. This is strongly recommended. + +To use the custom inpainting model, launch `invoke.py` with the +argument `--model inpainting-1.5` or alternatively from within the +script use the `!switch inpainting-1.5` command to load and switch to +the inpainting model. + +You can now do inpainting and outpainting exactly as described above, +but there will (likely) be a noticeable improvement in +coherence. Txt2img and Img2img will work as well. + +There are a few caveats to be aware of: + +1. The inpainting model is larger than the standard model, and will + use nearly 4 GB of GPU VRAM. This makes it unlikely to run on + a 4 GB graphics card. + +2. When operating in Img2img mode, the inpainting model is much less + steerable than the standard model. It is great for making small + changes, such as changing the pattern of a fabric, or slightly + changing a subject's expression or hair, but the model will + resist making the dramatic alterations that the standard + model lets you do. + +3. While the `--hires` option works fine with the inpainting model, + some special features, such as `--embiggen`, are disabled. + +4. Prompt weighting (`banana++ sushi`) and merging work well with + the inpainting model, but prompt swapping (`a ("fluffy cat").swap("smiling dog") eating a hotdog`) + will not have any effect due to the way the model is set up. + You may use text masking (with `-tm thing-to-mask`) as an + effective replacement. + +5. The model tends to oversharpen images if you use high step or CFG + values. If you need to use high step counts, use the standard model. + +6. The `--strength` (`-f`) option has no effect on the inpainting + model due to its fundamental differences from the standard + model. It will always take the full number of steps you specify. + +## Troubleshooting + +Here are some troubleshooting tips for inpainting and outpainting. + +### Inpainting is not changing the masked region enough! 
One of the things to understand about how inpainting works is that it is equivalent to running img2img on just the masked (transparent) diff --git a/docs/features/OUTPAINTING.md b/docs/features/OUTPAINTING.md index 1f1e1dbdfa..a6de893811 100644 --- a/docs/features/OUTPAINTING.md +++ b/docs/features/OUTPAINTING.md @@ -15,13 +15,52 @@ InvokeAI supports two versions of outpainting, one called "outpaint" and the other "outcrop." They work slightly differently and each has its advantages and drawbacks. +### Outpainting + +Outpainting is the same as inpainting, except that the painting occurs +in the regions outside of the original image. To outpaint using the +`invoke.py` command line script, prepare an image in which the borders +to be extended are pure black. Add an alpha channel (if there isn't one +already), and make the borders completely transparent and the interior +completely opaque. If you wish to modify the interior as well, you may +create transparent holes in the transparency layer, which `img2img` will +paint into as usual. + +Pass the image as the argument to the `-I` switch as you would for +regular inpainting: + + invoke> a stream by a river -I /path/to/transparent_img.png + +You'll likely be delighted by the results. + +### Tips + +1. Do not try to expand the image too much at once. Generally it is best + to expand the margins in 64-pixel increments. 128 pixels often works, + but your mileage may vary depending on the nature of the image you are + trying to outpaint into. + +2. There are a series of switches that can be used to adjust how the + inpainting algorithm operates. In particular, you can use these to + minimize the seam that sometimes appears between the original image + and the extended part. These switches are: + + --seam_size SEAM_SIZE Size of the mask around the seam between original and outpainted image (0) + --seam_blur SEAM_BLUR The amount to blur the seam inwards (0) + --seam_strength STRENGTH The img2img strength to use when filling the seam (0.7) + --seam_steps SEAM_STEPS The number of steps to use to fill the seam. (10) + --tile_size TILE_SIZE The tile size to use for filling outpaint areas (32) + ### Outcrop -The `outcrop` extension allows you to extend the image in 64 pixel -increments in any dimension. You can apply the module to any image -previously-generated by InvokeAI. Note that it will **not** work with -arbitrary photographs or Stable Diffusion images created by other -implementations. +The `outcrop` extension gives you a convenient `!fix` postprocessing +command that allows you to extend a previously-generated image in 64 +pixel increments in any direction. You can apply the module to any +image previously-generated by InvokeAI. Note that it works with +arbitrary PNG photographs, but not currently with JPG or other +formats. Outcropping is particularly effective when combined with the +[runwayML custom inpainting +model](INPAINTING.md#using-the-runwayml-inpainting-model). Consider this image: @@ -64,42 +103,3 @@ you'll get a slightly different result. You can run it repeatedly until you get an image you like. Unfortunately `!fix` does not currently respect the `-n` (`--iterations`) argument. -## Outpaint - -The `outpaint` extension does the same thing, but with subtle -differences. Starting with the same image, here is how we would add an -additional 64 pixels to the top of the image: - -```bash -invoke> !fix images/curly.png --out_direction top 64 -``` - -(you can abbreviate `--out_direction` as `-D`. - -The result is shown here: - -
-![curly_woman_outpaint](../assets/outpainting/curly-outpaint.png) -
- -Although the effect is similar, there are significant differences from -outcropping: - -- You can only specify one direction to extend at a time. -- The image is **not** resized. Instead, the image is shifted by the specified -number of pixels. If you look carefully, you'll see that less of the lady's -torso is visible in the image. -- Because the image dimensions remain the same, there's no rounding -to multiples of 64. -- Attempting to outpaint larger areas will frequently give rise to ugly - ghosting effects. -- For best results, try increasing the step number. -- If you don't specify a pixel value in `-D`, it will default to half - of the whole image, which is likely not what you want. - -!!! tip - - Neither `outpaint` nor `outcrop` are perfect, but we continue to tune - and improve them. If one doesn't work, try the other. You may also - wish to experiment with other `img2img` arguments, such as `-C`, `-f` - and `-s`. diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md index b5ef26858b..b1152ca435 100644 --- a/docs/features/PROMPTS.md +++ b/docs/features/PROMPTS.md @@ -45,7 +45,7 @@ Here's a prompt that depicts what it does. original prompt: -`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180` +`#!bash "A fantastical translucent pony made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
![step1](../assets/negative_prompt_walkthru/step1.png) @@ -84,6 +84,109 @@ Getting close - but there's no sense in having a saddle when our horse doesn't h --- +## **Prompt Syntax Features** + +The InvokeAI prompting language has the following features: + +### Attention weighting +Append a word or phrase with `-` or `+`, or a weight between `0` and `2` (`1`=default), to decrease or increase "attention" (= a mix of per-token CFG weighting multiplier and, for `-`, a weighted blend with the prompt without the term). + +The following syntax is recognised: + * single words without parentheses: `a tall thin man picking apricots+` + * single or multiple words with parentheses: `a tall thin man picking (apricots)+` `a tall thin man picking (apricots)-` `a tall thin man (picking apricots)+` `a tall thin man (picking apricots)-` + * more effect with more symbols `a tall thin man (picking apricots)++` + * nesting `a tall thin man (picking apricots+)++` (`apricots` effectively gets `+++`) + * all of the above with explicit numbers `a tall thin man picking (apricots)1.1` `a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++` is pow(1.1,2), `+++` is pow(1.1,3), etc; `-` means 0.9, `--` means pow(0.9,2), etc.) + * attention also applies to `[unconditioning]` so `a tall thin man picking apricots [(ladder)0.01]` will *very gently* nudge SD away from trying to draw the man on a ladder + +You can use this to increase or decrease the amount of something. Starting from this prompt of `a man picking apricots from a tree`, let's see what happens if we increase and decrease how much attention we want Stable Diffusion to pay to the word `apricots`: + +![an AI generated image of a man picking apricots from a tree](../assets/prompt_syntax/apricots-0.png) + +Using `-` to reduce apricot-ness: + +| `a man picking apricots- from a tree` | `a man picking apricots-- from a tree` | `a man picking apricots--- from a tree` | +| -- | -- | -- | +| ![an AI generated image of a man picking apricots from a tree, with smaller apricots](../assets/prompt_syntax/apricots--1.png) | ![an AI generated image of a man picking apricots from a tree, with even smaller and fewer apricots](../assets/prompt_syntax/apricots--2.png) | ![an AI generated image of a man picking apricots from a tree, with very few very small apricots](../assets/prompt_syntax/apricots--3.png) | + +Using `+` to increase apricot-ness: + +| `a man picking apricots+ from a tree` | `a man picking apricots++ from a tree` | `a man picking apricots+++ from a tree` | `a man picking apricots++++ from a tree` | `a man picking apricots+++++ from a tree` | +| -- | -- | -- | -- | -- | +| ![an AI generated image of a man picking apricots from a tree, with larger, more vibrant apricots](../assets/prompt_syntax/apricots-1.png) | ![an AI generated image of a man picking apricots from a tree with even larger, even more vibrant apricots](../assets/prompt_syntax/apricots-2.png) | ![an AI generated image of a man picking apricots from a tree, but the man has been replaced by a pile of apricots](../assets/prompt_syntax/apricots-3.png) | ![an AI generated image of a man picking apricots from a tree, but the man has been replaced by a mound of giant melting-looking apricots](../assets/prompt_syntax/apricots-4.png) | ![an AI generated image of a man picking apricots from a tree, but the man and the leaves and parts of the ground have all been replaced by giant melting-looking apricots](../assets/prompt_syntax/apricots-5.png) | + +You can also change the balance between 
different parts of a prompt. For example, below is a `mountain man`: + +![an AI generated image of a mountain man](../assets/prompt_syntax/mountain-man.png) + +And here he is with more mountain: + +| `mountain+ man` | `mountain++ man` | `mountain+++ man` | +| -- | -- | -- | +| ![](../assets/prompt_syntax/mountain1-man.png) | ![](../assets/prompt_syntax/mountain2-man.png) | ![](../assets/prompt_syntax/mountain3-man.png) | + +Or, alternatively, with more man: + +| `mountain man+` | `mountain man++` | `mountain man+++` | `mountain man++++` | +| -- | -- | -- | -- | +| ![](../assets/prompt_syntax/mountain-man1.png) | ![](../assets/prompt_syntax/mountain-man2.png) | ![](../assets/prompt_syntax/mountain-man3.png) | ![](../assets/prompt_syntax/mountain-man4.png) | + +### Blending between prompts + +* `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` +* The existing prompt blending using `:` will continue to be supported - `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` is equivalent to `a tall thin man picking apricots:1 a tall thin man picking pears:1` in the old syntax. +* Attention weights can be nested inside blends. +* Non-normalized blends are supported by passing `no_normalize` as an additional argument to the blend weights, eg `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,-1,no_normalize)`. very fun to explore local maxima in the feature space, but also easy to produce garbage output. + +See the section below on "Prompt Blending" for more information about how this works. + +### Cross-Attention Control ('prompt2prompt') + +Sometimes an image you generate is almost right, and you just want to +change one detail without affecting the rest. You could use a photo editor and inpainting +to overpaint the area, but that's a pain. Here's where `prompt2prompt` +comes in handy. + +Generate an image with a given prompt, record the seed of the image, +and then use the `prompt2prompt` syntax to substitute words in the +original prompt for words in a new prompt. This works for `img2img` as well. + +* `a ("fluffy cat").swap("smiling dog") eating a hotdog`. + * quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`. + * for single word substitutions parentheses are also optional: `a cat.swap(dog) eating a hotdog`. +* Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_spatial_start/_end` and `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to intuitively understand. + * Example usage:`a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` argument means that the "spatial" (self-attention) edit will stop having any effect after 30% (=0.3) of the steps have been done, leaving Stable Diffusion with 70% of the steps where it is free to decide for itself how to reshape the cat-form into a dog form. + * The numbers represent a percentage through the step sequence where the edits should happen. 0 means the start (noisy starting image), 1 is the end (final image). + * For img2img, the step sequence does not start at 0 but instead at (1-strength) - so if strength is 0.7, s_start and s_end must both be greater than 0.3 (1-0.7) to have any effect. +* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped. + * `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`. 
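To make the step-window arithmetic concrete, here is a hypothetical pair of invocations (the seed and the `cat.png` init image are made up). With `-s 30` and `s_end=0.3`, the self-attention edit only applies during roughly the first 9 steps; in the img2img case with `-f 0.7`, the schedule effectively starts at 0.3, so `s_end` must be greater than 0.3 to have any effect:

```bash
# txt2img: lock in the dog's shape early, then let the sampler refine freely
invoke> "a (cat).swap(dog, s_end=0.3) eating a hotdog" -s 30 -S 1234567890

# img2img at strength 0.7: steps run from 0.3 to 1.0, so s_end=0.5 still applies
invoke> "a (cat).swap(dog, s_end=0.5) eating a hotdog" -I cat.png -f 0.7 -s 30 -S 1234567890
```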
+ + +The `prompt2prompt` code is based on [bloc97's +colab](https://github.com/bloc97/CrossAttentionControl). + +Note that `prompt2prompt` is not currently working with the runwayML +inpainting model, and may never work due to the way this model is set +up. If you attempt to use `prompt2prompt` you will get the original +image back. However, since this model is so good at inpainting, a +good substitute is to use the `clipseg` text masking option: + +``` +invoke> a fluffy cat eating a hotdog +Outputs: +[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog +invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat +``` + +### Escaping parentheses () and speech marks "" + +If the model you are using has parentheses () or speech marks "" as +part of its syntax, you will need to "escape" these using a backslash, +so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt +parser will attempt to interpret the parentheses as part of the prompt +syntax and it will get confused. + ## **Prompt Blending** You may blend together different sections of the prompt to explore the diff --git a/docs/installation/INSTALL_DOCKER.md b/docs/installation/INSTALL_DOCKER.md index eb2e2ab39f..50c3d89c81 100644 --- a/docs/installation/INSTALL_DOCKER.md +++ b/docs/installation/INSTALL_DOCKER.md @@ -36,20 +36,6 @@ another environment with NVIDIA GPUs on-premises or in the cloud. ### Prerequisites -#### Get the data files - -Go to -[Hugging Face](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original), -and click "Access repository" to Download the model file `sd-v1-4.ckpt` (~4 GB) -to `~/Downloads`. You'll need to create an account but it's quick and free. - -Also download the face restoration model. - -```Shell -cd ~/Downloads -wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -``` - #### Install [Docker](https://github.com/santisbon/guides#docker) On the Docker Desktop app, go to Preferences, Resources, Advanced. Increase the @@ -57,86 +43,61 @@ CPUs and Memory to avoid this [Issue](https://github.com/invoke-ai/InvokeAI/issues/342). You may need to increase Swap and Disk image size too. +#### Get a Hugging Face token + +Go to [Hugging Face](https://huggingface.co/settings/tokens), create a token and +temporarily place it somewhere like an open text editor window (but don't save it! +Just keep it open; we need it in the next step.) + ### Setup Set the fork you want to use and other variables. -```Shell -TAG_STABLE_DIFFUSION="santisbon/stable-diffusion" -PLATFORM="linux/arm64" -GITHUB_STABLE_DIFFUSION="-b orig-gfpgan https://github.com/santisbon/stable-diffusion.git" -REQS_STABLE_DIFFUSION="requirements-linux-arm64.txt" -CONDA_SUBDIR="osx-arm64" +!!! tip -echo $TAG_STABLE_DIFFUSION -echo $PLATFORM -echo $GITHUB_STABLE_DIFFUSION -echo $REQS_STABLE_DIFFUSION -echo $CONDA_SUBDIR + I prefer to save my env vars + in a `.env` (or `.envrc`) file in the repository root so they are automatically + re-applied when I come back. + +The build and run scripts contain default values for almost everything, +besides the [Hugging Face token](https://huggingface.co/settings/tokens) you +created in the last step. 
+ +Some suggestions of variables you may want to change besides the token: + +| Environment-Variable | Description | | ------------------------------------------------------------------- | ------------------------------------------------------------------------ | | `HUGGINGFACE_TOKEN="hg_aewirhghlawrgkjbarug2"` | This is the only required variable; without it you can't get the checkpoint | | `ARCH=aarch64` | if you are using an ARM-based CPU | | `INVOKEAI_TAG=yourname/invokeai:latest` | the container repository/tag which will be used | | `INVOKEAI_CONDA_ENV_FILE=environment-linux-aarch64.yml` | since `environment.yml` does not work on aarch64 | | `INVOKEAI_GIT="-b branchname https://github.com/username/reponame"` | if you want to use your own fork | + +#### Build the Image + +I provided a build script, located in `docker-build/build.sh`, which still +needs to be executed from the repository root. + +```bash +docker-build/build.sh +``` + +The build script not only builds the container, but also creates the Docker +volume if it does not exist yet; if the volume is empty, it downloads the models. When +it is done you can run the container via the run script: + +```bash +docker-build/run.sh +``` + +When used without arguments, the container will start the web interface and provide +you the link to open it. If you want to pass other parameters, you can +do that as well. + +!!! warning "Deprecated" + + From here on, the rest of the previous Docker docs follow; they may still + provide useful information for some readers. 
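Before the legacy instructions that follow, here is a hypothetical end-to-end flow tying the variables and scripts together (the token value and the fork line are placeholders):

```bash
# example .env / .envrc kept in the repository root
export HUGGINGFACE_TOKEN="hf_xxxxxxxxxxxxxxxx"
export ARCH=aarch64                                   # only needed on ARM hosts
export INVOKEAI_CONDA_ENV_FILE=environment-linux-aarch64.yml
#export INVOKEAI_GIT="-b mybranch https://github.com/username/InvokeAI.git"

# from the repository root, with the variables above exported (e.g. via `source .env`)
docker-build/build.sh          # builds the image and seeds the Docker volume on first run
docker-build/run.sh            # no arguments: serves the web UI on port 9090
docker-build/run.sh "--help"   # arguments are forwarded to scripts/invoke.py instead
```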
## Usage (time to have fun) @@ -240,7 +201,8 @@ server with: python3 scripts/invoke.py --full_precision --web ``` -If it's running on your Mac point your Mac web browser to http://127.0.0.1:9090 +If it's running on your Mac point your Mac web browser to + Press Control-C at the command line to stop the web server. diff --git a/environment-linux-aarch64.yml b/environment-linux-aarch64.yml new file mode 100644 index 0000000000..f430770c4d --- /dev/null +++ b/environment-linux-aarch64.yml @@ -0,0 +1,44 @@ +name: invokeai +channels: + - pytorch + - conda-forge +dependencies: + - python>=3.9 + - pip>=20.3 + - cudatoolkit + - pytorch + - torchvision + - numpy=1.19 + - imageio=2.9.0 + - opencv=4.6.0 + - pillow=8.* + - flask=2.1.* + - flask_cors=3.0.10 + - flask-socketio=5.3.0 + - send2trash=1.8.0 + - eventlet + - albumentations=0.4.3 + - pudb=2019.2 + - imageio-ffmpeg=0.4.2 + - pytorch-lightning=1.7.7 + - streamlit + - einops=0.3.0 + - kornia=0.6 + - torchmetrics=0.7.0 + - transformers=4.21.3 + - torch-fidelity=0.3.0 + - tokenizers>=0.11.1,!=0.11.3,<0.13 + - pip: + - omegaconf==2.1.1 + - realesrgan==0.2.5.0 + - test-tube>=0.7.5 + - pyreadline3 + - dependency_injector==4.40.0 + - -e git+https://github.com/openai/CLIP.git@main#egg=clip + - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers + - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion + - -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan + - -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg + - -e . +variables: + PYTORCH_ENABLE_MPS_FALLBACK: 1 diff --git a/ldm/generate.py b/ldm/generate.py index 3ede1710e1..7f7bd43397 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -1,5 +1,5 @@ # Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein) - +import pyparsing # Derived from source code carrying the following copyrights # Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors @@ -24,6 +24,7 @@ from PIL import Image, ImageOps from torch import nn from pytorch_lightning import seed_everything, logging +from ldm.invoke.prompt_parser import PromptParser from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler @@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter from ldm.invoke.args import metadata_from_png from ldm.invoke.image_util import InitImageResizer from ldm.invoke.devices import choose_torch_device, choose_precision -from ldm.invoke.conditioning import get_uc_and_c +from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_cache import ModelCache from ldm.invoke.seamless import configure_model_padding from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale @@ -179,6 +180,7 @@ class Generate: self.size_matters = True # used to warn once about large image sizes and VRAM self.txt2mask = None self.safety_checker = None + self.karras_max = None # Note that in previous versions, there was an option to pass the # device to Generate(). 
However the device was then ignored, so @@ -269,10 +271,12 @@ class Generate: variation_amount = 0.0, threshold = 0.0, perlin = 0.0, + karras_max = None, # these are specific to img2img and inpaint init_img = None, init_mask = None, text_mask = None, + invert_mask = False, fit = False, strength = None, init_color = None, @@ -293,6 +297,13 @@ class Generate: catch_interrupts = False, hires_fix = False, use_mps_noise = False, + # Seam settings for outpainting + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + force_outpaint: bool = False, **args, ): # eat up additional cruft """ @@ -310,6 +321,7 @@ class Generate: init_img // path to an initial image init_mask // path to a mask for the initial image text_mask // a text string that will be used to guide clipseg generation of the init_mask + invert_mask // boolean, if true invert the mask strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) @@ -350,7 +362,8 @@ class Generate: strength = strength or self.strength self.seed = seed self.log_tokenization = log_tokenization - self.step_callback = step_callback + self.step_callback = step_callback + self.karras_max = karras_max with_variations = [] if with_variations is None else with_variations # will instantiate the model or return it from cache @@ -395,6 +408,11 @@ class Generate: self.sampler_name = sampler_name self._set_sampler() + # bit of a hack to change the cached sampler's karras threshold to + # whatever the user asked for + if karras_max is not None and isinstance(self.sampler,KSampler): + self.sampler.adjust_settings(karras_max=karras_max) + tic = time.time() if self._has_cuda(): torch.cuda.reset_peak_memory_stats() @@ -404,7 +422,7 @@ class Generate: mask_image = None try: - uc, c = get_uc_and_c( + uc, c, extra_conditioning_info = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=skip_normalize, log_tokens =self.log_tokenization @@ -417,19 +435,12 @@ class Generate: height, fit=fit, text_mask=text_mask, + invert_mask=invert_mask, + force_outpaint=force_outpaint, ) # TODO: Hacky selection of operation to perform. Needs to be refactored. 
- if (init_image is not None) and (mask_image is not None): - generator = self._make_inpaint() - elif (embiggen != None or embiggen_tiles != None): - generator = self._make_embiggen() - elif init_image is not None: - generator = self._make_img2img() - elif hires_fix: - generator = self._make_txt2img2img() - else: - generator = self._make_txt2img() + generator = self.select_generator(init_image, mask_image, embiggen, hires_fix) generator.set_variation( self.seed, variation_amount, with_variations @@ -448,7 +459,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c), + conditioning=(uc, c, extra_conditioning_info), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -464,7 +475,13 @@ class Generate: embiggen_tiles=embiggen_tiles, inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, - safety_checker=checker + safety_checker=checker, + seam_size = seam_size, + seam_blur = seam_blur, + seam_strength = seam_strength, + seam_steps = seam_steps, + tile_size = tile_size, + force_outpaint = force_outpaint ) if init_color: @@ -481,14 +498,14 @@ class Generate: save_original = save_original, image_callback = image_callback) - except RuntimeError as e: - print(traceback.format_exc(), file=sys.stderr) - print('>> Could not generate image.') except KeyboardInterrupt: if catch_interrupts: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt + except RuntimeError as e: + print(traceback.format_exc(), file=sys.stderr) + print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') @@ -545,7 +562,7 @@ class Generate: # try to reuse the same filename prefix as the original file. 
# we take everything up to the first period prefix = None - m = re.match('^([^.]+)\.',os.path.basename(image_path)) + m = re.match(r'^([^.]+)\.',os.path.basename(image_path)) if m: prefix = m.groups()[0] @@ -553,7 +570,8 @@ class Generate: image = Image.open(image_path) # used by multiple postfixers - uc, c = get_uc_and_c( + # todo: cross-attention control + uc, c, _ = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=opt.skip_normalize, log_tokens =opt.log_tokenization @@ -598,10 +616,9 @@ class Generate: elif tool == 'embiggen': # fetch the metadata from the image - generator = self._make_embiggen() + generator = self.select_generator(embiggen=True) opt.strength = 0.40 print(f'>> Setting img2img strength to {opt.strength} for happy embiggening') - # embiggen takes a image path (sigh) generator.generate( prompt, sampler = self.sampler, @@ -635,6 +652,32 @@ class Generate: print(f'* postprocessing tool {tool} is not yet supported') return None + def select_generator( + self, + init_image:Image.Image=None, + mask_image:Image.Image=None, + embiggen:bool=False, + hires_fix:bool=False, + force_outpaint:bool=False, + ): + inpainting_model_in_use = self.sampler.uses_inpainting_model() + + if hires_fix: + return self._make_txt2img2img() + + if embiggen is not None: + return self._make_embiggen() + + if inpainting_model_in_use: + return self._make_omnibus() + + if ((init_image is not None) and (mask_image is not None)) or force_outpaint: + return self._make_inpaint() + + if init_image is not None: + return self._make_img2img() + + return self._make_txt2img() def _make_images( self, @@ -644,6 +687,8 @@ class Generate: height, fit=False, text_mask=None, + invert_mask=False, + force_outpaint=False, ): init_image = None init_mask = None @@ -657,7 +702,7 @@ class Generate: # if image has a transparent area and no mask was provided, then try to generate mask if self._has_transparency(image): - self._transparency_check_and_warning(image, mask) + self._transparency_check_and_warning(image, mask, force_outpaint) init_mask = self._create_init_mask(image, width, height, fit=fit) if (image.width * image.height) > (self.width * self.height) and self.size_matters: @@ -673,8 +718,12 @@ class Generate: elif text_mask: init_mask = self._txt2mask(image, text_mask, width, height, fit=fit) + if invert_mask: + init_mask = ImageOps.invert(init_mask) + return init_image,init_mask + # lots o' repeated code here! 
Turn into a make_func() def _make_base(self): if not self.generators.get('base'): from ldm.invoke.generator import Generator @@ -685,6 +734,7 @@ class Generate: if not self.generators.get('img2img'): from ldm.invoke.generator.img2img import Img2Img self.generators['img2img'] = Img2Img(self.model, self.precision) + self.generators['img2img'].free_gpu_mem = self.free_gpu_mem return self.generators['img2img'] def _make_embiggen(self): @@ -713,6 +763,15 @@ class Generate: self.generators['inpaint'] = Inpaint(self.model, self.precision) return self.generators['inpaint'] + # "omnibus" supports the runwayML custom inpainting model, which does + # txt2img, img2img and inpainting using slight variations on the same code + def _make_omnibus(self): + if not self.generators.get('omnibus'): + from ldm.invoke.generator.omnibus import Omnibus + self.generators['omnibus'] = Omnibus(self.model, self.precision) + self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem + return self.generators['omnibus'] + def load_model(self): ''' preload model identified in self.model_name @@ -839,6 +898,8 @@ class Generate: def sample_to_image(self, samples): return self._make_base().sample_to_image(samples) + # very repetitive code - can this be simplified? The KSampler names are + # consistent, at least def _set_sampler(self): msg = f'>> Setting Sampler to {self.sampler_name}' if self.sampler_name == 'plms': @@ -846,15 +907,11 @@ class Generate: elif self.sampler_name == 'ddim': self.sampler = DDIMSampler(self.model, device=self.device) elif self.sampler_name == 'k_dpm_2_a': - self.sampler = KSampler( - self.model, 'dpm_2_ancestral', device=self.device - ) + self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device) elif self.sampler_name == 'k_dpm_2': self.sampler = KSampler(self.model, 'dpm_2', device=self.device) elif self.sampler_name == 'k_euler_a': - self.sampler = KSampler( - self.model, 'euler_ancestral', device=self.device - ) + self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device) elif self.sampler_name == 'k_euler': self.sampler = KSampler(self.model, 'euler', device=self.device) elif self.sampler_name == 'k_heun': @@ -888,8 +945,9 @@ class Generate: image = ImageOps.exif_transpose(image) return image - def _create_init_image(self, image, width, height, fit=True): - image = image.convert('RGB') + def _create_init_image(self, image: Image.Image, width, height, fit=True): + if image.mode != 'RGBA': + image = image.convert('RGBA') image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) return image @@ -954,11 +1012,11 @@ class Generate: colored += 1 return colored == 0 - def _transparency_check_and_warning(self,image, mask): + def _transparency_check_and_warning(self,image, mask, force_outpaint=False): if not mask: print( '>> Initial image has transparent areas. Will inpaint in these regions.') - if self._check_for_erasure(image): + if (not force_outpaint) and self._check_for_erasure(image): print( '>> WARNING: Colors underneath the transparent region seem to have been erased.\n', '>> Inpainting will be suboptimal. 
Please preserve the colors when making\n', diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index b24620df15..b509362644 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -83,16 +83,16 @@ with metadata_from_png(): import argparse from argparse import Namespace, RawTextHelpFormatter import pydoc -import shlex import json import hashlib import os import re +import shlex import copy import base64 import functools import ldm.invoke.pngwriter -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts SAMPLER_CHOICES = [ 'ddim', @@ -169,28 +169,31 @@ class Args(object): def parse_cmd(self,cmd_string): '''Parse a invoke>-style command string ''' - command = cmd_string.replace("'", "\\'") - try: - elements = shlex.split(command) - elements = [x.replace("\\'","'") for x in elements] - except ValueError: - import sys, traceback - print(traceback.format_exc(), file=sys.stderr) - return - switches = [''] - switches_started = False - - for element in elements: - if element[0] == '-' and not switches_started: - switches_started = True - if switches_started: - switches.append(element) + # handle the case in which the first token is a switch + if cmd_string.startswith('-'): + prompt = '' + switches = cmd_string + # handle the case in which the prompt is enclosed by quotes + elif cmd_string.startswith('"'): + a = shlex.split(cmd_string,comments=True) + prompt = a[0] + switches = shlex.join(a[1:]) + else: + # no initial quote, so get everything up to the first thing + # that looks like a switch + if cmd_string.startswith('-'): + prompt = '' + switches = cmd_string else: - switches[0] += element - switches[0] += ' ' - switches[0] = switches[0][: len(switches[0]) - 1] + match = re.match('^(.+?)\s(--?[a-zA-Z].+)',cmd_string) + if match: + prompt,switches = match.groups() + else: + prompt = cmd_string + switches = '' try: - self._cmd_switches = self._cmd_parser.parse_args(switches) + self._cmd_switches = self._cmd_parser.parse_args(shlex.split(switches,comments=True)) + setattr(self._cmd_switches,'prompt',prompt) return self._cmd_switches except: return None @@ -211,12 +214,16 @@ class Args(object): a = vars(self) a.update(kwargs) switches = list() - switches.append(f'"{a["prompt"]}"') + prompt = a['prompt'] + prompt.replace('"','\\"') + switches.append(prompt) switches.append(f'-s {a["steps"]}') switches.append(f'-S {a["seed"]}') switches.append(f'-W {a["width"]}') switches.append(f'-H {a["height"]}') switches.append(f'-C {a["cfg_scale"]}') + if a['karras_max'] is not None: + switches.append(f'--karras_max {a["karras_max"]}') if a['perlin'] > 0: switches.append(f'--perlin {a["perlin"]}') if a['threshold'] > 0: @@ -568,10 +575,17 @@ class Args(object): ) render_group = parser.add_argument_group('General rendering') img2img_group = parser.add_argument_group('Image-to-image and inpainting') + inpainting_group = parser.add_argument_group('Inpainting') + outpainting_group = parser.add_argument_group('Outpainting and outcropping') variation_group = parser.add_argument_group('Creating and combining variations') postprocessing_group = parser.add_argument_group('Post-processing') special_effects_group = parser.add_argument_group('Special effects') - render_group.add_argument('prompt') + deprecated_group = parser.add_argument_group('Deprecated options') + render_group.add_argument( + '--prompt', + default='', + help='prompt string', + ) render_group.add_argument( '-s', '--steps', @@ -689,7 +703,13 @@ class Args(object): default=6, 
choices=range(0,10), dest='png_compression', - help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.' + help='level of PNG compression, from 0 (none) to 9 (maximum). [6]' + ) + render_group.add_argument( + '--karras_max', + type=int, + default=None, + help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]." ) img2img_group.add_argument( '-I', @@ -697,12 +717,6 @@ class Args(object): type=str, help='Path to input image for img2img mode (supersedes width and height)', ) - img2img_group.add_argument( - '-M', - '--init_mask', - type=str, - help='Path to input mask for inpainting mode (supersedes width and height)', - ) img2img_group.add_argument( '-tm', '--text_mask', @@ -730,29 +744,68 @@ class Args(object): help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely', default=0.75, ) - img2img_group.add_argument( - '-D', - '--out_direction', - nargs='+', + inpainting_group.add_argument( + '-M', + '--init_mask', type=str, - metavar=('direction', 'pixels'), - help='Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size' + help='Path to input mask for inpainting mode (supersedes width and height)', ) - img2img_group.add_argument( - '-c', - '--outcrop', - nargs='+', - type=str, - metavar=('direction','pixels'), - help='Outcrop the image with one or more direction/pixel pairs: -c top 64 bottom 128 left 64 right 64', + inpainting_group.add_argument( + '--invert_mask', + action='store_true', + help='Invert the mask', ) - img2img_group.add_argument( + inpainting_group.add_argument( '-r', '--inpaint_replace', type=float, default=0.0, help='when inpainting, adjust how aggressively to replace the part of the picture under the mask, from 0.0 (a gentle merge) to 1.0 (replace entirely)', ) + outpainting_group.add_argument( + '-c', + '--outcrop', + nargs='+', + type=str, + metavar=('direction','pixels'), + help='Outcrop the image with one or more direction/pixel pairs: e.g. -c top 64 bottom 128 left 64 right 64', + ) + outpainting_group.add_argument( + '--force_outpaint', + action='store_true', + default=False, + help='Force outpainting if you have no inpainting mask to pass', + ) + outpainting_group.add_argument( + '--seam_size', + type=int, + default=0, + help='When outpainting, size of the mask around the seam between original and outpainted image', + ) + outpainting_group.add_argument( + '--seam_blur', + type=int, + default=0, + help='When outpainting, the amount to blur the seam inwards', + ) + outpainting_group.add_argument( + '--seam_strength', + type=float, + default=0.7, + help='When outpainting, the img2img strength to use when filling the seam. Values around 0.7 work well', + ) + outpainting_group.add_argument( + '--seam_steps', + type=int, + default=10, + help='When outpainting, the number of steps to use to fill the seam. 
Low values (~10) work well', + ) + outpainting_group.add_argument( + '--tile_size', + type=int, + default=32, + help='When outpainting, the tile size to use for filling outpaint areas', + ) postprocessing_group.add_argument( '-ft', '--facetool', @@ -836,7 +889,14 @@ class Args(object): dest='use_mps_noise', help='Simulate noise on M1 systems to get the same results' ) - + deprecated_group.add_argument( + '-D', + '--out_direction', + nargs='+', + type=str, + metavar=('direction', 'pixels'), + help='Older outcropping system. Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size' + ) return parser def format_metadata(**kwargs): @@ -872,7 +932,7 @@ def metadata_dumps(opt, # remove any image keys not mentioned in RFC #266 rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps', - 'cfg_scale','threshold','perlin','fnformat', 'step_number','width','height','extra','strength', + 'cfg_scale','threshold','perlin','step_number','width','height','extra','strength', 'init_img','init_mask','facetool','facetool_strength','upscale'] rfc_dict ={} @@ -923,6 +983,23 @@ def metadata_dumps(opt, return metadata +@functools.lru_cache(maxsize=50) +def args_from_png(png_file_path) -> list[Args]: + ''' + Given the path to a PNG file created by invoke.py, + retrieves a list of Args objects containing the image + data. + ''' + try: + meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path) + except AttributeError: + return [legacy_metadata_load({},png_file_path)] + + try: + return metadata_loads(meta) + except: + return [legacy_metadata_load(meta,png_file_path)] + @functools.lru_cache(maxsize=50) def metadata_from_png(png_file_path) -> Args: ''' @@ -930,11 +1007,8 @@ def metadata_from_png(png_file_path) -> Args: an Args object containing the image metadata. Note that this returns a single Args object, not multiple. ''' - meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path) - if 'sd-metadata' in meta and len(meta['sd-metadata'])>0 : - return metadata_loads(meta)[0] - else: - return legacy_metadata_load(meta,png_file_path) + args_list = args_from_png(png_file_path) + return args_list[0] def dream_cmd_from_png(png_file_path): opt = metadata_from_png(png_file_path) @@ -949,7 +1023,7 @@ def metadata_loads(metadata) -> list: ''' results = [] try: - if 'grid' in metadata['sd-metadata']: + if 'images' in metadata['sd-metadata']: images = metadata['sd-metadata']['images'] else: images = [metadata['sd-metadata']['image']] diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fedd965a2c..9b775823aa 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -4,107 +4,191 @@ weighted subprompts. 
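Callers of the rewritten conditioning module now unpack a triple rather than a pair and pass the extra element through the generators to the sampler. A minimal sketch of the calling side, assuming model is a loaded LatentDiffusion instance as in the Generate hunks above:

from ldm.invoke.conditioning import get_uc_and_c_and_ec

uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
    'a sunlit forest, highly detailed [blurry, watermark]',   # bracketed words become the negative prompt
    model=model,
    log_tokens=True,
)
# generators receive the whole triple:
conditioning = (uc, c, extra_conditioning_info)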
Useful function exports: -get_uc_and_c() get the conditioned and unconditioned latent +get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control split_weighted_subpromopts() split subprompts, normalize and weight them log_tokenization() print out colour-coded tokens and warn if truncated ''' import re +from difflib import SequenceMatcher +from typing import Union + import torch -def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): +from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ + CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment +from ..models.diffusion.cross_attention_control import CrossAttentionControl +from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder + + +def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False): + # Extract Unconditioned Words From Prompt unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' - unconditionals = re.findall(unconditional_regex, prompt) + unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) if len(unconditionals) > 0: unconditioned_words = ' '.join(unconditionals) # Remove Unconditioned Words From Prompt unconditional_regex_compile = re.compile(unconditional_regex) - clean_prompt = unconditional_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned) + prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) + else: + prompt_string_cleaned = prompt_string_uncleaned - uc = model.get_learned_conditioning([unconditioned_words]) + pp = PromptParser() - # get weighted sub-prompts - weighted_subprompts = split_weighted_subprompts( - prompt, skip_normalize + parsed_prompt: Union[FlattenedPrompt, Blend] = None + legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned) + if legacy_blend is not None: + parsed_prompt = legacy_blend + else: + # we don't support conjunctions for now + parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0] + + parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0] + print(f">> Parsed prompt to {parsed_prompt}") + + conditioning = None + cac_args:CrossAttentionControl.Arguments = None + + if type(parsed_prompt) is Blend: + blend: Blend = parsed_prompt + embeddings_to_blend = None + for flattened_prompt in blend.prompts: + this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens) + embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( + (embeddings_to_blend, this_embedding)) + conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), + blend.weights, + normalize=blend.normalize_weights) + else: + flattened_prompt: FlattenedPrompt = parsed_prompt + wants_cross_attention_control = type(flattened_prompt) is not Blend \ + and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children]) + if wants_cross_attention_control: + original_prompt = FlattenedPrompt() + edited_prompt = FlattenedPrompt() + # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed + original_token_count = 0 + edited_token_count = 0 + edit_opcodes = [] + edit_options = [] + for fragment in 
flattened_prompt.children: + if type(fragment) is CrossAttentionControlSubstitute: + original_prompt.append(fragment.original) + edited_prompt.append(fragment.edited) + + to_replace_token_count = get_tokens_length(model, fragment.original) + replacement_token_count = get_tokens_length(model, fragment.edited) + edit_opcodes.append(('replace', + original_token_count, original_token_count + to_replace_token_count, + edited_token_count, edited_token_count + replacement_token_count + )) + original_token_count += to_replace_token_count + edited_token_count += replacement_token_count + edit_options.append(fragment.options) + #elif type(fragment) is CrossAttentionControlAppend: + # edited_prompt.append(fragment.fragment) + else: + # regular fragment + original_prompt.append(fragment) + edited_prompt.append(fragment) + + count = get_tokens_length(model, [fragment]) + edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count)) + edit_options.append(None) + original_token_count += count + edited_token_count += count + original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens) + # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of + # subsequent tokens when there is >1 edit and earlier edits change the total token count. + # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the + # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra + # token 'smiling' in the inactive 'cat' edit. + # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions + edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens) + + conditioning = original_embeddings + edited_conditioning = edited_embeddings + #print('>> got edit_opcodes', edit_opcodes, 'options', edit_options) + cac_args = CrossAttentionControl.Arguments( + edited_conditioning = edited_conditioning, + edit_opcodes = edit_opcodes, + edit_options = edit_options + ) + else: + conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens) + + unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens) + if isinstance(conditioning, dict): + # hybrid conditioning is in play + unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning) + if cac_args is not None: + print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.") + cac_args = None + + return ( + unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo( + cross_attention_control_args=cac_args + ) ) - if len(weighted_subprompts) > 1: - # i dont know if this is correct.. 
but it works - c = torch.zeros_like(uc) - # normalize each "sub prompt" and add it - for subprompt, weight in weighted_subprompts: - log_tokenization(subprompt, model, log_tokens, weight) - c = torch.add( - c, - model.get_learned_conditioning([subprompt]), - alpha=weight, - ) - else: # just standard 1 prompt - log_tokenization(prompt, model, log_tokens, 1) - c = model.get_learned_conditioning([prompt]) - uc = model.get_learned_conditioning([unconditioned_words]) - return (uc, c) -def split_weighted_subprompts(text, skip_normalize=False)->list: - """ - grabs all text up to the first occurrence of ':' - uses the grabbed text as a sub-prompt, and takes the value following ':' as weight - if ':' has no value defined, defaults to 1.0 - repeats until no text remaining - """ - prompt_parser = re.compile(""" - (?P # capture group for 'prompt' - (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' - ) # end 'prompt' - (?: # non-capture group - :+ # match one or more ':' characters - (?P # capture group for 'weight' - -?\d+(?:\.\d+)? # match positive or negative integer or decimal number - )? # end weight capture group, make optional - \s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group - """, re.VERBOSE) - parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( - match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] - if skip_normalize: - return parsed_prompts - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - if weight_sum == 0: - print( - "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.") - equal_weight = 1 / max(len(parsed_prompts), 1) - return [(x[0], equal_weight) for x in parsed_prompts] - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] +def build_token_edit_opcodes(original_tokens, edited_tokens): + original_tokens = original_tokens.cpu().numpy()[0] + edited_tokens = edited_tokens.cpu().numpy()[0] -# shows how the prompt is tokenized -# usually tokens have '' to indicate end-of-word, -# but for readability it has been replaced with ' ' -def log_tokenization(text, model, log=False, weight=1): - if not log: - return - tokens = model.cond_stage_model.tokenizer._tokenize(text) - tokenized = "" - discarded = "" - usedTokens = 0 - totalTokens = len(tokens) - for i in range(0, totalTokens): - token = tokens[i].replace('', ' ') - # alternate color - s = (usedTokens % 6) + 1 - if i < model.cond_stage_model.max_length: - tokenized = tokenized + f"\x1b[0;3{s};40m{token}" - usedTokens += 1 - else: # over max token length - discarded = discarded + f"\x1b[0;3{s};40m{token}" - print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") - if discarded != "": - print( - f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" - ) + return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes() + +def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False): + if type(flattened_prompt) is not FlattenedPrompt: + raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead") + fragments = [x.text for x in flattened_prompt.children] + weights = [x.weight for x in flattened_prompt.children] + embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights]) + if not flattened_prompt.is_empty and log_tokens: + start_token = 
model.cond_stage_model.tokenizer.bos_token_id + end_token = model.cond_stage_model.tokenizer.eos_token_id + tokens_list = tokens[0].tolist() + if tokens_list[0] == start_token: + tokens_list[0] = '' + try: + first_end_token_index = tokens_list.index(end_token) + tokens_list[first_end_token_index] = '' + tokens_list = tokens_list[:first_end_token_index+1] + except ValueError: + pass + + print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}") + + return embeddings, tokens + +def get_tokens_length(model, fragments: list[Fragment]): + fragment_texts = [x.text for x in fragments] + tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False) + return sum([len(x) for x in tokens]) + +def flatten_hybrid_conditioning(uncond, cond): + ''' + This handles the choice between a conditional conditioning + that is a tensor (used by cross attention) vs one that has additional + dimensions as well, as used by 'hybrid' + ''' + assert isinstance(uncond, dict) + assert isinstance(cond, dict) + cond_flattened = dict() + for k in cond: + if isinstance(cond[k], list): + cond_flattened[k] = [ + torch.cat([uncond[k][i], cond[k][i]]) + for i in range(len(cond[k])) + ] + else: + cond_flattened[k] = torch.cat([uncond[k], cond[k]]) + return uncond, cond_flattened + + diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index d42013ea73..77b92d693d 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -6,6 +6,7 @@ import torch import numpy as np import random import os +import traceback from tqdm import tqdm, trange from PIL import Image, ImageFilter from einops import rearrange, repeat @@ -28,7 +29,8 @@ class Generator(): self.threshold = 0 self.variation_amount = 0 self.with_variations = [] - self.use_mps_noise = False + self.use_mps_noise = False + self.free_gpu_mem = None # this is going to be overridden in img2img.py, txt2img.py and inpaint.py def get_make_image(self,prompt,**kwargs): @@ -43,14 +45,15 @@ class Generator(): self.variation_amount = variation_amount self.with_variations = with_variations - def generate(self,prompt,init_image,width,height,iterations=1,seed=None, + def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, safety_checker:dict=None, **kwargs): scope = choose_autocast(self.precision) self.safety_checker = safety_checker - make_image = self.get_make_image( + make_image = self.get_make_image( prompt, + sampler = sampler, init_image = init_image, width = width, height = height, @@ -59,12 +62,14 @@ class Generator(): perlin = perlin, **kwargs ) - results = [] seed = seed if seed is not None else self.new_seed() first_seed = seed seed, initial_noise = self.generate_initial_noise(seed, width, height) - with scope(self.model.device.type), self.model.ema_scope(): + + # There used to be an additional self.model.ema_scope() here, but it breaks + # the inpaint-1.5 model. Not sure what it did.... ? 
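flatten_hybrid_conditioning() above pairs each unconditioned entry with its conditioned counterpart so that hybrid-conditioned models see a single concatenated batch. A toy illustration of the shape handling, with invented tensor sizes:

import torch

uncond = {'c_concat': [torch.zeros(1, 5, 64, 64)], 'c_crossattn': [torch.zeros(1, 77, 768)]}
cond   = {'c_concat': [torch.ones(1, 5, 64, 64)],  'c_crossattn': [torch.ones(1, 77, 768)]}

flattened = {}
for k in cond:
    if isinstance(cond[k], list):
        # list-valued entries are concatenated element-wise, unconditioned first
        flattened[k] = [torch.cat([uncond[k][i], cond[k][i]]) for i in range(len(cond[k]))]
    else:
        flattened[k] = torch.cat([uncond[k], cond[k]])
# every entry now has batch size 2: unconditioned in slot 0, conditioned in slot 1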
+ with scope(self.model.device.type): for n in trange(iterations, desc='Generating'): x_T = None if self.variation_amount > 0: @@ -79,7 +84,8 @@ class Generator(): try: x_T = self.get_noise(width,height) except: - pass + print('** An error occurred while getting initial noise **') + print(traceback.format_exc()) image = make_image(x_T) @@ -95,10 +101,10 @@ class Generator(): return results - def sample_to_image(self,samples): + def sample_to_image(self,samples)->Image.Image: """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it + Given samples returned from a sampler, converts + it into a PIL Image """ x_samples = self.model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/ldm/invoke/generator/embiggen.py b/ldm/invoke/generator/embiggen.py index 53fbde68cf..dc6af35a6c 100644 --- a/ldm/invoke/generator/embiggen.py +++ b/ldm/invoke/generator/embiggen.py @@ -21,6 +21,7 @@ class Embiggen(Generator): def generate(self,prompt,iterations=1,seed=None, image_callback=None, step_callback=None, **kwargs): + scope = choose_autocast(self.precision) make_image = self.get_make_image( prompt, @@ -63,6 +64,8 @@ class Embiggen(Generator): Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image Return value depends on the seed at the time you call it """ + assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models" + # Construct embiggen arg array, and sanity check arguments if embiggen == None: # embiggen can also be called with just embiggen_tiles embiggen = [1.0] # If not specified, assume no scaling diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 613f1aca31..1981b4eacb 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -10,11 +10,12 @@ from PIL import Image from ldm.invoke.devices import choose_autocast from ldm.invoke.generator.base import Generator from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent class Img2Img(Generator): def __init__(self, model, precision): super().__init__(model, precision) - self.init_latent = None # by get_noise() + self.init_latent = None # by get_noise() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs): @@ -29,7 +30,7 @@ class Img2Img(Generator): ) if isinstance(init_image, PIL.Image.Image): - init_image = self._image_to_tensor(init_image) + init_image = self._image_to_tensor(init_image.convert('RGB')) scope = choose_autocast(self.precision) with scope(self.model.device.type): @@ -38,7 +39,7 @@ class Img2Img(Generator): ) # move to latent space t_enc = int(strength * steps) - uc, c = conditioning + uc, c, extra_conditioning_info = conditioning def make_image(x_T): # encode (scaled latent) @@ -55,7 +56,9 @@ class Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - init_latent = self.init_latent, # changes how noising is performed in ksampler + init_latent = self.init_latent, # changes how noising is performed in ksampler + extra_conditioning_info = extra_conditioning_info, + all_timesteps_count = steps ) return self.sample_to_image(samples) @@ -77,7 +80,10 @@ class 
Img2Img(Generator): def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor: image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) + if len(image.shape) == 2: # 'L' image, as in a mask + image = image[None,None] + else: # 'RGB' image + image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) if normalize: image = 2.0 * image - 1.0 diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 8fbcf249aa..a6ab0cd06a 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -2,12 +2,13 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator ''' +import math import torch import torchvision.transforms as T import numpy as np import cv2 as cv import PIL -from PIL import Image, ImageFilter +from PIL import Image, ImageFilter, ImageOps from skimage.exposure.histogram_matching import match_histograms from einops import rearrange, repeat from ldm.invoke.devices import choose_autocast @@ -24,11 +25,128 @@ class Inpaint(Img2Img): self.mask_blur_radius = 0 super().__init__(model, precision) + # Outpaint support code + def get_tile_images(self, image: np.ndarray, width=8, height=8): + _nrows, _ncols, depth = image.shape + _strides = image.strides + + nrows, _m = divmod(_nrows, height) + ncols, _n = divmod(_ncols, width) + if _m != 0 or _n != 0: + return None + + return np.lib.stride_tricks.as_strided( + np.ravel(image), + shape=(nrows, ncols, height, width, depth), + strides=(height * _strides[0], width * _strides[1], *_strides), + writeable=False + ) + + def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image: + a = np.asarray(im, dtype=np.uint8) + + tile_size = (tile_size, tile_size) + + # Get the image as tiles of a specified size + tiles = self.get_tile_images(a,*tile_size).copy() + + # Get the mask as tiles + tiles_mask = tiles[:,:,:,:,3] + + # Find any mask tiles with any fully transparent pixels (we will be replacing these later) + tmask_shape = tiles_mask.shape + tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape)) + n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:]) + tiles_mask = (tiles_mask > 0) + tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1) + + # Get RGB tiles in single array and filter by the mask + tshape = tiles.shape + tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:])) + filtered_tiles = tiles_all[tiles_mask] + + if len(filtered_tiles) == 0: + return im + + # Find all invalid tiles and replace with a random valid tile + replace_count = (tiles_mask == False).sum() + rng = np.random.default_rng(seed = seed) + tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:] + + # Convert back to an image + tiles_all = tiles_all.reshape(tshape) + tiles_all = tiles_all.swapaxes(1,2) + st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4])) + si = Image.fromarray(st, mode='RGBA') + + return si + + + def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image: + npimg = np.asarray(mask, dtype=np.uint8) + + # Detect any partially transparent regions + npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))) + + # Detect hard edges + npedge = cv.Canny(npimg, threshold1=100, threshold2=200) + + # Combine + npmask = npgradient + npedge + + # Expand + npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = 
int(edge_size / 2)) + + new_mask = Image.fromarray(npmask) + + if edge_blur > 0: + new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur)) + + return ImageOps.invert(new_mask) + + + def seam_paint(self, + im: Image.Image, + seam_size: int, + seam_blur: int, + prompt,sampler,steps,cfg_scale,ddim_eta, + conditioning,strength, + noise + ) -> Image.Image: + hard_mask = self.pil_image.split()[-1].copy() + mask = self.mask_edge(hard_mask, seam_size, seam_blur) + + make_image = self.get_make_image( + prompt, + sampler, + steps, + cfg_scale, + ddim_eta, + conditioning, + init_image = im.copy().convert('RGBA'), + mask_image = mask.convert('RGB'), # Code currently requires an RGB mask + strength = strength, + mask_blur_radius = 0, + seam_size = 0 + ) + + result = make_image(noise) + + return result + + @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,mask_image,strength, mask_blur_radius: int = 8, - step_callback=None,inpaint_replace=False, **kwargs): + # Seam settings - when 0, doesn't fill seam + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + step_callback=None, + inpaint_replace=False, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image + mask. Return value depends on the seed at @@ -37,7 +155,17 @@ class Inpaint(Img2Img): if isinstance(init_image, PIL.Image.Image): self.pil_image = init_image - init_image = self._image_to_tensor(init_image) + + # Fill missing areas of original image + init_filled = self.tile_fill_missing( + self.pil_image.copy(), + seed = self.seed, + tile_size = tile_size + ) + init_filled.paste(init_image, (0,0), init_image.split()[-1]) + + # Create init tensor + init_image = self._image_to_tensor(init_filled.convert('RGB')) if isinstance(mask_image, PIL.Image.Image): self.pil_mask = mask_image @@ -73,7 +201,8 @@ class Inpaint(Img2Img): ) # move to latent space t_enc = int(strength * steps) - uc, c = conditioning + # todo: support cross-attention control + uc, c, _ = conditioning print(f">> target t_enc is {t_enc} steps") @@ -105,38 +234,56 @@ class Inpaint(Img2Img): mask = mask_image, init_latent = self.init_latent ) - return self.sample_to_image(samples) + + result = self.sample_to_image(samples) + + # Seam paint if this is our first pass (seam_size set to 0 during seam painting) + if seam_size > 0: + result = self.seam_paint( + result, + seam_size, + seam_blur, + prompt, + sampler, + seam_steps, + cfg_scale, + ddim_eta, + conditioning, + seam_strength, + x_T) + + return result return make_image - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') - - if self.pil_image is None or self.pil_mask is None: - return gen_result - - pil_mask = self.pil_mask - pil_image = self.pil_image - mask_blur_radius = self.mask_blur_radius + def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image: # Get the original alpha channel of the mask if there is one. 
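Before the first diffusion pass, fully transparent tiles are replaced with randomly chosen opaque tiles from the same image so img2img has plausible pixels to start from, and the original pixels are then pasted back on top. That step in isolation looks roughly like this, assuming inpaint is an Inpaint instance and using a made-up file name:

from PIL import Image

im = Image.open('outcrop_me.png').convert('RGBA')           # transparent border to be outpainted
filled = inpaint.tile_fill_missing(im.copy(), tile_size=32, seed=42)
filled.paste(im, (0, 0), im.split()[-1])                    # keep the original pixels wherever alpha > 0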
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB') - pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L') - pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist + pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L') + pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. - # Note that this doesn't use the mask, which would exclude some source image pixels from the - # histogram and cause slight color changes. - init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) - init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height) - init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0] - init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram + init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8) + init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) + init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) - # Get numpy version - np_gen_result = np.asarray(gen_result, dtype=np.uint8) + # Get numpy version of result + np_image = np.asarray(image, dtype=np.uint8) + + # Mask and calculate mean and standard deviation + mask_pixels = init_a_pixels * init_mask_pixels > 0 + np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] + np_image_masked = np_image[mask_pixels, :] + + init_means = np_init_rgb_pixels_masked.mean(axis=0) + init_std = np_init_rgb_pixels_masked.std(axis=0) + gen_means = np_image_masked.mean(axis=0) + gen_std = np_image_masked.std(axis=0) # Color correct - np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) + np_matched_result = np_image.copy() + np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) matched_result = Image.fromarray(np_matched_result, mode='RGB') # Blur the mask out (into init image) by specified amount @@ -149,6 +296,16 @@ class Inpaint(Img2Img): blurred_init_mask = pil_init_mask # Paste original on color-corrected generation (using blurred mask) - matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) + matched_result.paste(base_image, (0,0), mask = blurred_init_mask) return matched_result + + def sample_to_image(self, samples)->Image.Image: + gen_result = super().sample_to_image(samples).convert('RGB') + + if self.pil_image is None or self.pil_mask is None: + return gen_result + + corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) + + return corrected_result diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py new file mode 100644 index 0000000000..e8426a9205 --- /dev/null +++ b/ldm/invoke/generator/omnibus.py @@ -0,0 +1,153 @@ +"""omnibus module to be used with the runwayml 9-channel custom inpainting model""" + +import torch +import numpy as np +from einops import repeat +from PIL import Image, ImageOps +from ldm.invoke.devices import choose_autocast +from ldm.invoke.generator.base import downsampling +from ldm.invoke.generator.img2img import Img2Img 
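The color_correct() method above drops skimage's match_histograms in favour of a per-channel mean/standard-deviation transfer. Stripped down to toy numpy arrays (in the real method both sets of statistics are gathered only where the init alpha and the mask overlap), the core transform is:

import numpy as np

rng        = np.random.default_rng(0)
np_image   = rng.integers(0, 256, (64, 64, 3)).astype(np.float32)   # generated result
source_ref = rng.integers(0, 256, (500, 3)).astype(np.float32)      # masked pixels of the source image

init_means, init_std = source_ref.mean(axis=0), source_ref.std(axis=0)
gen_means,  gen_std  = np_image.reshape(-1, 3).mean(axis=0), np_image.reshape(-1, 3).std(axis=0)

# normalize the generated pixels, then re-scale them to the source statistics
matched = (((np_image - gen_means) / gen_std) * init_std + init_means).clip(0, 255).astype(np.uint8)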
+from ldm.invoke.generator.txt2img import Txt2Img + +class Omnibus(Img2Img,Txt2Img): + def __init__(self, model, precision): + super().__init__(model, precision) + + def get_make_image( + self, + prompt, + sampler, + steps, + cfg_scale, + ddim_eta, + conditioning, + width, + height, + init_image = None, + mask_image = None, + strength = None, + step_callback=None, + threshold=0.0, + perlin=0.0, + **kwargs): + """ + Returns a function returning an image derived from the prompt and the initial image + Return value depends on the seed at the time you call it. + """ + self.perlin = perlin + num_samples = 1 + + sampler.make_schedule( + ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False + ) + + if isinstance(init_image, Image.Image): + if init_image.mode != 'RGB': + init_image = init_image.convert('RGB') + init_image = self._image_to_tensor(init_image) + + if isinstance(mask_image, Image.Image): + mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False) + + t_enc = steps + + if init_image is not None and mask_image is not None: # inpainting + masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero + + elif init_image is not None: # img2img + scope = choose_autocast(self.precision) + + with scope(self.model.device.type): + self.init_latent = self.model.get_first_stage_encoding( + self.model.encode_first_stage(init_image) + ) # move to latent space + + # create a completely black mask (1s) + mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device) + # and the masked image is just a copy of the original + masked_image = init_image + + else: # txt2img + init_image = torch.zeros(1, 3, height, width, device=self.model.device) + mask_image = torch.ones(1, 1, height, width, device=self.model.device) + masked_image = init_image + + self.init_latent = init_image + height = init_image.shape[2] + width = init_image.shape[3] + model = self.model + + def make_image(x_T): + with torch.no_grad(): + scope = choose_autocast(self.precision) + with scope(self.model.device.type): + + batch = self.make_batch_sd( + init_image, + mask_image, + masked_image, + prompt=prompt, + device=model.device, + num_samples=num_samples, + ) + + c = model.cond_stage_model.encode(batch["txt"]) + c_cat = list() + for ck in model.concat_keys: + cc = batch[ck].float() + if ck != model.masked_image_key: + bchw = [num_samples, 4, height//8, width//8] + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + + # cond + cond={"c_concat": [c_cat], "c_crossattn": [c]} + + # uncond cond + uc_cross = model.get_unconditional_conditioning(num_samples, "") + uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} + shape = [model.channels, height//8, width//8] + + samples, _ = sampler.sample( + batch_size = 1, + S = steps, + x_T = x_T, + conditioning = cond, + shape = shape, + verbose = False, + unconditional_guidance_scale = cfg_scale, + unconditional_conditioning = uc_full, + eta = 1.0, + img_callback = step_callback, + threshold = threshold, + ) + if self.free_gpu_mem: + self.model.model.to("cpu") + return self.sample_to_image(samples) + + return make_image + + def make_batch_sd( + self, + image, + mask, + masked_image, + prompt, + device, + num_samples=1): + batch = { + "image": repeat(image.to(device=device), "1 ... 
-> n ...", n=num_samples), + "txt": num_samples * [prompt], + "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), + "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples), + } + return batch + + def get_noise(self, width:int, height:int): + if self.init_latent is not None: + height = self.init_latent.shape[2] + width = self.init_latent.shape[3] + return Txt2Img.get_noise(self,width,height) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 86eb679c04..ba49d2ef55 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator import torch import numpy as np from ldm.invoke.generator.base import Generator +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent + class Txt2Img(Generator): def __init__(self, model, precision): @@ -19,7 +21,7 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin - uc, c = conditioning + uc, c, extra_conditioning_info = conditioning @torch.no_grad() def make_image(x_T): @@ -43,6 +45,7 @@ class Txt2Img(Generator): verbose = False, unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, + extra_conditioning_info = extra_conditioning_info, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 4e68d94875..759ba2dba4 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -5,9 +5,11 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator import torch import numpy as np import math -from ldm.invoke.generator.base import Generator +from ldm.invoke.generator.base import Generator from ldm.models.diffusion.ddim import DDIMSampler - +from ldm.invoke.generator.omnibus import Omnibus +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from PIL import Image class Txt2Img2Img(Generator): def __init__(self, model, precision): @@ -22,31 +24,29 @@ class Txt2Img2Img(Generator): Return value depends on the seed at the time you call it kwargs are 'width' and 'height' """ - uc, c = conditioning + uc, c, extra_conditioning_info = conditioning + scale_dim = min(width, height) + scale = 512 / scale_dim + + init_width = math.ceil(scale * width / 64) * 64 + init_height = math.ceil(scale * height / 64) * 64 @torch.no_grad() - def make_image(x_T): - - trained_square = 512 * 512 - actual_square = width * height - scale = math.sqrt(trained_square / actual_square) + def make_image(x_T): - init_width = math.ceil(scale * width / 64) * 64 - init_height = math.ceil(scale * height / 64) * 64 - shape = [ self.latent_channels, init_height // self.downsampling_factor, init_width // self.downsampling_factor, ] - + sampler.make_schedule( ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ) - + #x = self.get_noise(init_width, init_height) x = x_T - + if self.free_gpu_mem and self.model.model.device != self.model.device: self.model.model.to(self.model.device) @@ -60,17 +60,18 @@ class Txt2Img2Img(Generator): unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, eta = ddim_eta, - img_callback = step_callback + img_callback = step_callback, + extra_conditioning_info = extra_conditioning_info ) - + print( f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling" ) - + # resizing samples = 
torch.nn.functional.interpolate( - samples, - size=(height // self.downsampling_factor, width // self.downsampling_factor), + samples, + size=(height // self.downsampling_factor, width // self.downsampling_factor), mode="bilinear" ) @@ -94,6 +95,8 @@ class Txt2Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, + extra_conditioning_info=extra_conditioning_info, + all_timesteps_count=steps ) if self.free_gpu_mem: @@ -101,8 +104,49 @@ class Txt2Img2Img(Generator): return self.sample_to_image(samples) - return make_image - + # in the case of the inpainting model being loaded, the trick of + # providing an interpolated latent doesn't work, so we transiently + # create a 512x512 PIL image, upscale it, and run the inpainting + # over it in img2img mode. Because the inpaing model is so conservative + # it doesn't change the image (much) + def inpaint_make_image(x_T): + omnibus = Omnibus(self.model,self.precision) + result = omnibus.generate( + prompt, + sampler=sampler, + width=init_width, + height=init_height, + step_callback=step_callback, + steps = steps, + cfg_scale = cfg_scale, + ddim_eta = ddim_eta, + conditioning = conditioning, + **kwargs + ) + assert result is not None and len(result)>0,'** txt2img failed **' + image = result[0][0] + interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS) + print(kwargs.pop('init_image',None)) + result = omnibus.generate( + prompt, + sampler=sampler, + init_image=interpolated_image, + width=width, + height=height, + seed=result[0][1], + step_callback=step_callback, + steps = steps, + cfg_scale = cfg_scale, + ddim_eta = ddim_eta, + conditioning = conditioning, + **kwargs + ) + return result[0][0] + + if sampler.uses_inpainting_model(): + return inpaint_make_image + else: + return make_image # returns a tensor filled with random numbers from a normal distribution def get_noise(self,width,height,scale = True): @@ -116,7 +160,7 @@ class Txt2Img2Img(Generator): else: scaled_width = width scaled_height = height - + device = self.model.device if self.use_mps_noise or device.type == 'mps': return torch.randn([1, @@ -130,3 +174,4 @@ class Txt2Img2Img(Generator): scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor], device=device) + diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index f580dfba25..f972a9eb16 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -13,6 +13,7 @@ import gc import hashlib import psutil import transformers +import traceback import os from sys import getrefcount from omegaconf import OmegaConf @@ -73,6 +74,7 @@ class ModelCache(object): self.models[model_name]['hash'] = hash except Exception as e: print(f'** model {model_name} could not be loaded: {str(e)}') + print(traceback.format_exc()) print(f'** restoring {self.current_model}') self.get_model(self.current_model) return None diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py new file mode 100644 index 0000000000..e6856761ac --- /dev/null +++ b/ldm/invoke/prompt_parser.py @@ -0,0 +1,686 @@ +import string +from typing import Union, Optional +import re +import pyparsing as pp + +class Prompt(): + """ + Mid-level structure for storing the tree-like result of parsing a prompt. 
A Prompt may not represent the whole of + the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects + that are to be blended together are stored individuall as Prompt objects. + + Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction + to produce a FlattenedPrompt. + """ + def __init__(self, parts: list): + for c in parts: + if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults: + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed") + self.children = parts + def __repr__(self): + return f"Prompt:{self.children}" + def __eq__(self, other): + return type(other) is Prompt and other.children == self.children + +class BaseFragment: + pass + +class FlattenedPrompt(): + """ + A Prompt that has been passed through flatten(). Its children can be readily tokenized. + """ + def __init__(self, parts: list=[]): + self.children = [] + for part in parts: + self.append(part) + + def append(self, fragment: Union[list, BaseFragment, tuple]): + # verify type correctness + if type(fragment) is list: + for x in fragment: + self.append(x) + elif issubclass(type(fragment), BaseFragment): + self.children.append(fragment) + elif type(fragment) is tuple: + # upgrade tuples to Fragments + if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int): + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") + self.children.append(Fragment(fragment[0], fragment[1])) + else: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") + + @property + def is_empty(self): + return len(self.children) == 0 or \ + (len(self.children) == 1 and len(self.children[0].text) == 0) + + def __repr__(self): + return f"FlattenedPrompt:{self.children}" + def __eq__(self, other): + return type(other) is FlattenedPrompt and other.children == self.children + + +class Fragment(BaseFragment): + """ + A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer. + """ + def __init__(self, text: str, weight: float=1): + assert(type(text) is str) + if '\\"' in text or '\\(' in text or '\\)' in text: + #print("Fragment converting escaped \( \) \\\" into ( ) \"") + text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"') + self.text = text + self.weight = float(weight) + + def __repr__(self): + return "Fragment:'"+self.text+"'@"+str(self.weight) + def __eq__(self, other): + return type(other) is Fragment \ + and other.text == self.text \ + and other.weight == self.weight + +class Attention(): + """ + Nestable weight control for fragments. Each object in the children array may in turn be an Attention object; + weights should be considered to accumulate as the tree is traversed to deeper levels of nesting. + + Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object. 
+ """ + def __init__(self, weight: float, children: list): + self.weight = weight + self.children = children + #print(f"A: requested attention '{children}' to {weight}") + + def __repr__(self): + return f"Attention:'{self.children}' @ {self.weight}" + def __eq__(self, other): + return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment + +class CrossAttentionControlledFragment(BaseFragment): + pass + +class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): + """ + A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt. + Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an + "edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively, + the result should be an "edited" image that looks like the "original" image with concepts swapped. + + eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look + almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The + CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped: + CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]) + or it may represent a larger portion of the token sequence: + CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')], + edited=[Fragment('a smiling dog sitting on a car')]) + + In either case expect it to be embedded in a Prompt or FlattenedPrompt: + FlattenedPrompt([ + Fragment('a'), + CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]), + Fragment('sitting on a car') + ]) + """ + def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None): + self.original = original + self.edited = edited + + default_options = { + 's_start': 0.0, + 's_end': 0.2062994740159002, # ~= shape_freedom=0.5 + 't_start': 0.0, + 't_end': 1.0 + } + merged_options = default_options + if options is not None: + shape_freedom = options.pop('shape_freedom', None) + if shape_freedom is not None: + # high shape freedom = SD can do what it wants with the shape of the object + # high shape freedom => s_end = 0 + # low shape freedom => s_end = 1 + # shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0, + # and there is very little perceptible difference as s_end increases above 0.5 + # so for shape_freedom = 0.5 we probably want s_end to be 0.2 + # -> cube root and subtract from 1.0 + merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.) 
+ #print('converted shape_freedom argument to', merged_options) + merged_options.update(options) + + self.options = merged_options + + def __repr__(self): + return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})" + def __eq__(self, other): + return type(other) is CrossAttentionControlSubstitute \ + and other.original == self.original \ + and other.edited == self.edited \ + and other.options == self.options + + +class CrossAttentionControlAppend(CrossAttentionControlledFragment): + def __init__(self, fragment: Fragment): + self.fragment = fragment + def __repr__(self): + return "CrossAttentionControlAppend:",self.fragment + def __eq__(self, other): + return type(other) is CrossAttentionControlAppend \ + and other.fragment == self.fragment + + + +class Conjunction(): + """ + Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged + by weighted sum in latent space. + """ + def __init__(self, prompts: list, weights: list = None): + # force everything to be a Prompt + #print("making conjunction with", parts) + self.prompts = [x if (type(x) is Prompt + or type(x) is Blend + or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.weights = [1.0]*len(self.prompts) if weights is None else list(weights) + if len(self.weights) != len(self.prompts): + raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") + self.type = 'AND' + + def __repr__(self): + return f"Conjunction:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return type(other) is Conjunction \ + and other.prompts == self.prompts \ + and other.weights == self.weights + + +class Blend(): + """ + Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a + weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the + Prompts. + """ + def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): + #print("making Blend with prompts", prompts, "and weights", weights) + if len(prompts) != len(weights): + raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") + for c in prompts: + if type(c) is not Prompt and type(c) is not FlattenedPrompt: + raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts")) + # upcast all lists to Prompt objects + self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.prompts = prompts + self.weights = weights + self.normalize_weights = normalize_weights + + def __repr__(self): + return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}" + def __eq__(self, other): + return other.__repr__() == self.__repr__() + + +class PromptParser(): + + class ParsingException(Exception): + pass + + def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): + + self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base) + + + def parse_conjunction(self, prompt: str) -> Conjunction: + ''' + :param prompt: The prompt string to parse + :return: a Conjunction representing the parsed results. 
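A short usage sketch of the parser entry points (the prompt strings are made up; get_uc_and_c_and_ec() above follows exactly this pattern):

pp = PromptParser()

conjunction = pp.parse_conjunction('a cat sitting on a car')
flattened   = conjunction.prompts[0]            # a FlattenedPrompt, ready for tokenization

blend = pp.parse_legacy_blend('a sunny meadow:2 a stormy meadow:1')
# returns a Blend with normalized weights, or None when the string has no ':weight' parts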
+ ''' + #print(f"!!parsing '{prompt}'") + + if len(prompt.strip()) == 0: + return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) + + root = self.conjunction.parse_string(prompt) + #print(f"'{prompt}' parsed to root", root) + #fused = fuse_fragments(parts) + #print("fused to", fused) + + return self.flatten(root[0]) + + def parse_legacy_blend(self, text: str) -> Optional[Blend]: + weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False) + if len(weighted_subprompts) <= 1: + return None + strings = [x[0] for x in weighted_subprompts] + weights = [x[1] for x in weighted_subprompts] + + parsed_conjunctions = [self.parse_conjunction(x) for x in strings] + flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] + + return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True) + + + def flatten(self, root: Conjunction) -> Conjunction: + """ + Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends, + producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects + that can be readily tokenized without the need to walk a complex tree structure. + + :param root: The Conjunction to flatten. + :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root. + """ + + #print("flattening", root) + + def fuse_fragments(items): + # print("fusing fragments in ", items) + result = [] + for x in items: + if type(x) is CrossAttentionControlSubstitute: + original_fused = fuse_fragments(x.original) + edited_fused = fuse_fragments(x.edited) + result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options)) + else: + last_weight = result[-1].weight \ + if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ + else None + this_text = x.text + this_weight = x.weight + if last_weight is not None and last_weight == this_weight: + last_text = result[-1].text + result[-1] = Fragment(last_text + ' ' + this_text, last_weight) + else: + result.append(x) + return result + + def flatten_internal(node, weight_scale, results, prefix): + #print(prefix + "flattening", node, "...") + if type(node) is pp.ParseResults: + for x in node: + results = flatten_internal(x, weight_scale, results, prefix+' pr ') + #print(prefix, " ParseResults expanded, results is now", results) + elif type(node) is Attention: + # if node.weight < 1: + # todo: inject a blend when flattening attention with weight <1" + for index,c in enumerate(node.children): + results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ") + elif type(node) is Fragment: + results += [Fragment(node.text, node.weight*weight_scale)] + elif type(node) is CrossAttentionControlSubstitute: + original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ') + edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ') + results += [CrossAttentionControlSubstitute(original, edited, options=node.options)] + elif type(node) is Blend: + flattened_subprompts = [] + #print(" flattening blend with prompts", node.prompts, "weights", node.weights) + for prompt in node.prompts: + # prompt is a list + flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') + results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)] + elif type(node) is Prompt: + #print(prefix + "about to 
flatten Prompt with children", node.children) + flattened_prompt = [] + for child in node.children: + flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ') + results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))] + #print(prefix + "after flattening Prompt, results is", results) + else: + raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") + #print(prefix + "-> after flattening", type(node).__name__, "results is", results) + return results + + + flattened_parts = [] + for part in root.prompts: + flattened_parts += flatten_internal(part, 1.0, [], ' C| ') + + #print("flattened to", flattened_parts) + + weights = root.weights + return Conjunction(flattened_parts, weights) + + + +def build_parser_syntax(attention_plus_base: float, attention_minus_base: float): + + lparen = pp.Literal("(").suppress() + rparen = pp.Literal(")").suppress() + quotes = pp.Literal('"').suppress() + comma = pp.Literal(",").suppress() + + # accepts int or float notation, always maps to float + number = pp.pyparsing_common.real | \ + pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float)) + + attention = pp.Forward() + quoted_fragment = pp.Forward() + parenthesized_fragment = pp.Forward() + cross_attention_substitute = pp.Forward() + + def make_text_fragment(x): + #print("### making fragment for", x) + if type(x[0]) is Fragment: + assert(False) + if type(x) is str: + return Fragment(x) + elif type(x) is pp.ParseResults or type(x) is list: + #print(f'converting {type(x).__name__} to Fragment') + return Fragment(' '.join([s for s in x])) + else: + raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + + def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str): + escapes = [] + for c in escaped_chars_to_ignore: + escapes.append(pp.Literal('\\'+c)) + return pp.Combine(pp.OneOrMore( + pp.MatchFirst(escapes + [pp.CharsNotIn( + string.whitespace + escaped_chars_to_ignore, + exact=1 + )]) + )) + + + + def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False): + #print(f"parsing fragment string for {x}") + fragment_string = x[0] + #print(f"ppparsing fragment string \"{fragment_string}\"") + + if len(fragment_string.strip()) == 0: + return Fragment('') + + if in_quotes: + # escape unescaped quotes + fragment_string = fragment_string.replace('"', '\\"') + + #fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment)))) + try: + result = pp.Group(pp.MatchFirst([ + pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'), + pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd() + ])).set_name('blend-result').set_debug(False).parse_string(fragment_string) + #print("parsed to", result) + return result + except pp.ParseException as e: + #print("parse_fragment_str couldn't parse prompt string:", e) + raise + + quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') + quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment') + + escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"') + escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(') + escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')') + escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"') + + empty = ( + (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | + 
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') + + + def not_ends_with_swap(x): + #print("trying to match:", x) + return not x[0].endswith('.swap') + + unquoted_word = (pp.Combine(pp.OneOrMore( + escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash | + (pp.CharsNotIn(string.whitespace + '\\"()', exact=1) + ))) + # don't whitespace when the next word starts with +, eg "badly +formed" + + (pp.White().suppress() | + # don't eat +/- + pp.NotAny(pp.Word('+') | pp.Word('-')) + ) + ) + + unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False) + #print(unquoted_fragment.parse_string("cat.swap(dog)")) + + parenthesized_fragment << (lparen + + pp.Or([ + (parenthesized_fragment), + (quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False), + (pp.Combine(pp.OneOrMore( + escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash | + pp.CharsNotIn(string.whitespace + '\\"()', exact=1) | + pp.White() + )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)), + pp.Empty() + ]) + rparen) + parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) + + debug_attention = False + # attention control of the form (phrase)+ / (phrase)+ / (phrase) + # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight + attention_with_parens = pp.Forward() + attention_without_parens = pp.Forward() + + attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\ + .set_name("attention_foot")\ + .set_debug(False) + attention_with_parens <<= pp.Group( + lparen + + pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens | + (pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0]) + ) + + rparen + attention_with_parens_foot) + attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) + + attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots') + attention_without_parens <<= pp.Group(pp.MatchFirst([ + quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot, + pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x)) + + attention_without_parens_foot#.leave_whitespace() + ])) + attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) + + + attention << pp.MatchFirst([attention_with_parens, + attention_without_parens + ]) + attention.set_name('attention') + + def make_attention(x): + #print("entered make_attention with", x) + children = x[0][:-1] + weight_raw = x[0][-1] + weight = 1.0 + if type(weight_raw) is float or type(weight_raw) is int: + weight = weight_raw + elif type(weight_raw) is str: + base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base + weight = pow(base, len(weight_raw)) + + #print("making Attention from", children, "with weight", 
weight) + + return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children]) + + attention_with_parens.set_parse_action(make_attention) + attention_without_parens.set_parse_action(make_attention) + + #print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1")) + + # cross-attention control + empty_string = ((lparen + rparen) | + pp.Literal('""').suppress() | + (lparen + pp.Literal('""').suppress() + rparen) + ).set_parse_action(lambda x: Fragment("")) + empty_string.set_name('empty_string') + + # cross attention control + debug_cross_attention_control = False + original_fragment = pp.MatchFirst([ + quoted_fragment.set_debug(debug_cross_attention_control), + parenthesized_fragment.set_debug(debug_cross_attention_control), + pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"), + empty_string.set_debug(debug_cross_attention_control), + ]) + # support keyword=number arguments + cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")]) + cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number) + edited_fragment = pp.MatchFirst([ + (lparen + rparen).set_parse_action(lambda x: Fragment('')), + lparen + + (quoted_fragment | attention | + pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment))) + ) + + pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) + + rparen, + parenthesized_fragment + ]) + cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment + + original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) + edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control) + cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) + + def make_cross_attention_substitute(x): + #print("making cacs for", x[0], "->", x[1], "with options", x.as_dict()) + #if len(x>2): + cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict()) + #print("made", cacs) + return cacs + cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + + + # root prompt definition + debug_root_prompt = False + prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt), + attention.set_debug(debug_root_prompt), + quoted_fragment.set_debug(debug_root_prompt), + parenthesized_fragment.set_debug(debug_root_prompt), + unquoted_word.set_debug(debug_root_prompt), + empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)]) + ) + pp.StringEnd()) \ + .set_name('prompt') \ + .set_parse_action(lambda x: Prompt(x)) \ + .set_debug(debug_root_prompt) + + #print("parsing test:", prompt.parse_string("spaced eyes--")) + #print("parsing test:", prompt.parse_string("eyes--")) + + # weighted blend of prompts + # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or + # int weights. 
+ # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + + def make_prompt_from_quoted_string(x): + #print(' got quoted prompt', x) + + x_unquoted = x[0][1:-1] + if len(x_unquoted.strip()) == 0: + # print(' b : just an empty string') + return Prompt([Fragment('')]) + #print(f' b parsing \'{x_unquoted}\'') + x_parsed = prompt.parse_string(x_unquoted) + #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) + return x_parsed[0] + + quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) + quoted_prompt.set_name('quoted_prompt') + + debug_blend=False + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend) + blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend) + blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + + pp.Literal(".blend").suppress() + + lparen + pp.Group(blend_weights) + rparen).set_name('blend') + blend.set_debug(debug_blend) + + def make_blend(x): + prompts = x[0][0] + weights = x[0][1] + normalize = True + if weights[-1] == 'no_normalize': + normalize = False + weights = weights[:-1] + return Blend(prompts=prompts, weights=weights, normalize_weights=normalize) + + blend.set_parse_action(make_blend) + + conjunction_terms = blend_terms.copy().set_name('conjunction_terms') + conjunction_weights = blend_weights.copy().set_name('conjunction_weights') + conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen + + pp.Literal(".and").suppress() + + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') + def make_conjunction(x): + parts_raw = x[0][0] + weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) + parts = [part for part in parts_raw] + return Conjunction(parts, weights) + conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) + + implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction') + implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) + + conjunction = conjunction_with_parens_and_quotes | implicit_conjunction + conjunction.set_debug(False) + + # top-level is a conjunction of one or more blends or prompts + return conjunction, prompt + + + +def split_weighted_subprompts(text, skip_normalize=False)->list: + """ + Legacy blend parsing. + + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\d+(?:\.\d+)? # match positive or negative integer or decimal number + )? # end weight capture group, make optional + \s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group + """, re.VERBOSE) + parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( + match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] + if skip_normalize: + return parsed_prompts + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + if weight_sum == 0: + print( + "Warning: Subprompt weights add up to zero. 
Discarding and using even weights instead.") + equal_weight = 1 / max(len(parsed_prompts), 1) + return [(x[0], equal_weight) for x in parsed_prompts] + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + + +# shows how the prompt is tokenized +# usually tokens have '' to indicate end-of-word, +# but for readability it has been replaced with ' ' +def log_tokenization(text, model, log=False, weight=1): + if not log: + return + tokens = model.cond_stage_model.tokenizer._tokenize(text) + tokenized = "" + discarded = "" + usedTokens = 0 + totalTokens = len(tokens) + for i in range(0, totalTokens): + token = tokens[i].replace('', 'x` ') + # alternate color + s = (usedTokens % 6) + 1 + if i < model.cond_stage_model.max_length: + tokenized = tokenized + f"\x1b[0;3{s};40m{token}" + usedTokens += 1 + else: # over max token length + discarded = discarded + f"\x1b[0;3{s};40m{token}" + print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") + if discarded != "": + print( + f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" + ) diff --git a/ldm/invoke/restoration/outcrop.py b/ldm/invoke/restoration/outcrop.py index 017d9de7e1..0c4831dd84 100644 --- a/ldm/invoke/restoration/outcrop.py +++ b/ldm/invoke/restoration/outcrop.py @@ -89,6 +89,9 @@ class Outcrop(object): def _extend(self,image:Image,pixels:int)-> Image: extended_img = Image.new('RGBA',(image.width,image.height+pixels)) + mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \ + else pixels *2 + # first paste places old image at top of extended image, stretch # it, and applies a gaussian blur to it # take the top half region, stretch and paste it @@ -105,7 +108,9 @@ class Outcrop(object): # now make the top part transparent to use as a mask alpha = extended_img.getchannel('A') - alpha.paste(0,(0,0,extended_img.width,pixels*2)) + alpha.paste(0,(0,0,extended_img.width,mask_height)) extended_img.putalpha(alpha) + extended_img.save('outputs/curly_extended.png') + return extended_img diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py index 359f5688d1..3db7b6fd73 100644 --- a/ldm/models/autoencoder.py +++ b/ldm/models/autoencoder.py @@ -66,7 +66,7 @@ class VQModel(pl.LightningModule): self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self) - print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') + print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.') if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py new file mode 100644 index 0000000000..1e5b073a3d --- /dev/null +++ b/ldm/models/diffusion/cross_attention_control.py @@ -0,0 +1,238 @@ +from enum import Enum + +import torch + +# adapted from bloc97's CrossAttentionControl colab +# https://github.com/bloc97/CrossAttentionControl + +class CrossAttentionControl: + + class Arguments: + def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): + """ + :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768] + :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required) + :param edit_options: if doing cross-attention control, per-edit options. 
there should be 1 item in edit_options for each item in edit_opcodes. + """ + # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector + self.edited_conditioning = edited_conditioning + self.edit_opcodes = edit_opcodes + + if edited_conditioning is not None: + assert len(edit_opcodes) == len(edit_options), \ + "there must be 1 edit_options dict for each edit_opcodes tuple" + non_none_edit_options = [x for x in edit_options if x is not None] + assert len(non_none_edit_options)>0, "missing edit_options" + if len(non_none_edit_options)>1: + print('warning: cross-attention control options are not working properly for >1 edit') + self.edit_options = non_none_edit_options[0] + + class Context: + def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): + """ + :param arguments: Arguments for the cross-attention control process + :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) + """ + self.arguments = arguments + self.step_count = step_count + + @classmethod + def remove_cross_attention_control(cls, model): + cls.remove_attention_function(model) + + @classmethod + def setup_cross_attention_control(cls, model, + cross_attention_control_args: Arguments + ): + """ + Inject attention parameters and functions into the passed in model to enable cross attention editing. + + :param model: The unet model to inject into. + :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations + :return: None + """ + + # adapted from init_attention_edit + device = cross_attention_control_args.edited_conditioning.device + + # urgh. should this be hardcoded? + max_length = 77 + # mask=1 means use base prompt attention, mask=0 means use edited prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.zeros(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes: + if b0 < max_length: + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + # these tokens have not been edited + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 + + for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): + m.last_attn_slice_mask = None + m.last_attn_slice_indices = None + + for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS): + m.last_attn_slice_mask = mask.to(device) + m.last_attn_slice_indices = indices.to(device) + + cls.inject_attention_function(model) + + + class CrossAttentionType(Enum): + SELF = 1 + TOKENS = 2 + + @classmethod + def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\ + -> list['CrossAttentionControl.CrossAttentionType']: + """ + Should cross-attention control be applied on the given step? + :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. + :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. 
+ """ + if percent_through is None: + return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS] + + opts = context.arguments.edit_options + to_control = [] + if opts['s_start'] <= percent_through and percent_through < opts['s_end']: + to_control.append(cls.CrossAttentionType.SELF) + if opts['t_start'] <= percent_through and percent_through < opts['t_end']: + to_control.append(cls.CrossAttentionType.TOKENS) + return to_control + + + @classmethod + def get_attention_modules(cls, model, which: CrossAttentionType): + which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" + return [module for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name] + + @classmethod + def clear_requests(cls, model): + self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.save_last_attn_slice = False + m.use_last_attn_slice = False + + @classmethod + def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + # clear out the saved slice in case the outermost dim changes + m.last_attn_slice = None + m.save_last_attn_slice = True + + @classmethod + def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + m.use_last_attn_slice = True + + + + @classmethod + def inject_attention_function(cls, unet): + # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + + def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): + + #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) + + attn_slice = suggested_attention_slice + if dim is not None: + start = offset + end = start+slice_size + #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + #else: + # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + + + if self.use_last_attn_slice: + this_attn_slice = attn_slice + if self.last_attn_slice_mask is not None: + # indices and mask operate on dim=2, no need to slice + base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + base_attn_slice_mask = self.last_attn_slice_mask + if dim is None: + base_attn_slice = base_attn_slice_full + #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 0: + base_attn_slice = base_attn_slice_full[start:end] + #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 1: + base_attn_slice = base_attn_slice_full[:, start:end] + #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + + attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ + base_attn_slice * base_attn_slice_mask + else: + if dim is None: + attn_slice = self.last_attn_slice + #print("took whole slice 
of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 0: + attn_slice = self.last_attn_slice[start:end] + #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 1: + attn_slice = self.last_attn_slice[:, start:end] + #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + + if self.save_last_attn_slice: + if dim is None: + self.last_attn_slice = attn_slice + elif dim == 0: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + #print("no last_attn_slice: shape now", self.last_attn_slice.shape) + elif self.last_attn_slice.shape[0] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) + assert(self.last_attn_slice.shape[0] == end) + #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + else: + # no need to grow + self.last_attn_slice[start:end] = attn_slice + #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + + elif dim == 1: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + elif self.last_attn_slice.shape[1] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) + assert(self.last_attn_slice.shape[1] == end) + else: + # no need to grow + self.last_attn_slice[:, start:end] = attn_slice + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + if dim is None: + weights = self.last_attn_slice_weights + elif dim == 0: + weights = self.last_attn_slice_weights[start:end] + elif dim == 1: + weights = self.last_attn_slice_weights[:, start:end] + attn_slice = attn_slice * weights + + return attn_slice + + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.last_attn_slice = None + module.last_attn_slice_indices = None + module.last_attn_slice_mask = None + module.use_last_attn_weights = False + module.use_last_attn_slice = False + module.save_last_attn_slice = False + module.set_attention_slice_wrangler(attention_slice_wrangler) + + @classmethod + def remove_attention_function(cls, unet): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_attention_slice_wrangler(None) + diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index f5dada8627..e5a502f977 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -1,10 +1,7 @@ """SAMPLING ONLY.""" import torch -import numpy as np -from tqdm import tqdm -from functools import partial -from ldm.invoke.devices import choose_torch_device +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like @@ -12,6 +9,21 @@ class DDIMSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps,device) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, + model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) + + def prepare_to_sample(self, t_enc, **kwargs): + super().prepare_to_sample(t_enc, **kwargs) + + 
extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) + + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + else: + self.invokeai_diffuser.remove_cross_attention_control() + + # This is the central routine @torch.no_grad() def p_sample( @@ -29,6 +41,7 @@ class DDIMSampler(Sampler): corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = *x.shape, x.device @@ -37,16 +50,17 @@ class DDIMSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): + # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond + # step_index counts in the opposite direction to index + step_index = step_count-(index+1) + e_t = self.invokeai_diffuser.do_diffusion_step( + x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index ) - if score_corrector is not None: assert self.model.parameterization == 'eps' e_t = score_corrector.modify_score( diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 4b62b5e393..a5989532c4 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -19,6 +19,7 @@ from functools import partial from tqdm import tqdm from torchvision.utils import make_grid from pytorch_lightning.utilities.distributed import rank_zero_only +from omegaconf import ListConfig import urllib from ldm.util import ( @@ -120,7 +121,7 @@ class DDPM(pl.LightningModule): self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) - print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') + print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.') self.use_scheduler = scheduler_config is not None if self.use_scheduler: @@ -820,21 +821,21 @@ class LatentDiffusion(DDPM): ) return self.scale_factor * z - def get_learned_conditioning(self, c): + def get_learned_conditioning(self, c, **kwargs): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager + c, embedding_manager=self.embedding_manager,**kwargs ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: - c = self.cond_stage_model(c) + c = self.cond_stage_model(c, **kwargs) else: assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs) return c def meshgrid(self, h, w): @@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM): return samples, intermediates + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = 
xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + return c + @torch.no_grad() def log_images( self, @@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule): cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc) elif self.conditioning_key == 'hybrid': - xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) + xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t, context=cc) elif self.conditioning_key == 'adm': cc = c_crossattn[0] @@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion): cond_img = torch.stack(bbox_imgs, dim=0) logs['bbox_image'] = cond_img return logs + +class LatentInpaintDiffusion(LatentDiffusion): + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + finetune_keys=None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.concat_keys = concat_keys + + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = ( + rearrange(batch[ck], "b h w c -> b c h w") + .to(memory_format=torch.contiguous_format) + .float() + ) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ac0615b30c..54d78b302b 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -1,16 +1,16 @@ """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" + import k_diffusion as K import torch -import torch.nn as nn -from ldm.invoke.devices import choose_torch_device -from ldm.models.diffusion.sampler import Sampler -from ldm.util import rand_perlin_2d -from ldm.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) +from torch import nn + +from .sampler import Sampler +from .shared_invokeai_diffusion import InvokeAIDiffuserComponent + + +# at this threshold, the scheduler will stop using the Karras +# noise schedule and start using the model's schedule +STEP_THRESHOLD = 29 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: @@ -33,12 +33,21 @@ class CFGDenoiser(nn.Module): self.threshold = threshold self.warmup_max = warmup self.warmup = max(warmup / 10, 1) + self.invokeai_diffuser = InvokeAIDiffuserComponent(model, + model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, 
cond=cond)) + + def prepare_to_sample(self, t_enc, **kwargs): + + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) + else: + self.invokeai_diffuser.remove_cross_attention_control() + def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 @@ -46,8 +55,7 @@ class CFGDenoiser(nn.Module): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh) - + return cfg_apply_threshold(next_x, thresh) class KSampler(Sampler): def __init__(self, model, schedule='lms', device=None, **kwargs): @@ -60,16 +68,9 @@ class KSampler(Sampler): self.sigmas = None self.ds = None self.s_in = None - - def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model( - x_in, sigma_in, cond=cond_in - ).chunk(2) - return uncond + (cond - uncond) * cond_scale - + self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD) + if self.karras_max is None: + self.karras_max = STEP_THRESHOLD def make_schedule( self, @@ -98,8 +99,13 @@ class KSampler(Sampler): rho=7., device=self.device, ) - self.sigmas = self.model_sigmas - #self.sigmas = self.karras_sigmas + + if ddim_num_steps >= self.karras_max: + print(f'>> Ksampler using model noise schedule (steps > {self.karras_max})') + self.sigmas = self.model_sigmas + else: + print(f'>> Ksampler using karras noise schedule (steps <= {self.karras_max})') + self.sigmas = self.karras_sigmas # ALERT: We are completely overriding the sample() method in the base class, which # means that inpainting will not work. To get this to work we need to be able to @@ -118,6 +124,7 @@ class KSampler(Sampler): use_original_steps=False, init_latent = None, mask = None, + **kwargs ): samples,_ = self.sample( batch_size = 1, @@ -129,7 +136,8 @@ class KSampler(Sampler): unconditional_conditioning = unconditional_conditioning, img_callback = img_callback, x0 = init_latent, - mask = mask + mask = mask, + **kwargs ) return samples @@ -163,6 +171,7 @@ class KSampler(Sampler): log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + extra_conditioning_info=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... @@ -181,7 +190,6 @@ class KSampler(Sampler): ) # sigmas are set up in make_schedule - we take the last steps items - total_steps = len(self.sigmas) sigmas = self.sigmas[-S-1:] # x_T is variation noise. 
When an init image is provided (in x0) we need to add @@ -195,19 +203,21 @@ class KSampler(Sampler): x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) + model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale, } print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)') - return ( + sampling_result = ( K.sampling.__dict__[f'sample_{self.schedule}']( model_wrap_cfg, x, sigmas, extra_args=extra_args, callback=route_callback ), None, ) + return sampling_result # this code will support inpainting if and when ksampler API modified or # a workaround is found. @@ -220,6 +230,7 @@ class KSampler(Sampler): index, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + extra_conditioning_info=None, **kwargs, ): if self.model_wrap is None: @@ -245,6 +256,7 @@ class KSampler(Sampler): # so the actual formula for indexing into sigmas: # sigma_index = (steps-index) s_index = t_enc - index - 1 + self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info) img = K.sampling.__dict__[f'_{self.schedule}']( self.model_wrap, img, @@ -269,7 +281,7 @@ class KSampler(Sampler): else: return x - def prepare_to_sample(self,t_enc): + def prepare_to_sample(self,t_enc,**kwargs): self.t_enc = t_enc self.model_wrap = None self.ds = None @@ -281,3 +293,6 @@ class KSampler(Sampler): ''' return self.model.inner_model.q_sample(x0,ts) + def conditioning_key(self)->str: + return self.model.inner_model.model.conditioning_key + diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 9e722eb932..5124badcd1 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -5,6 +5,7 @@ import numpy as np from tqdm import tqdm from functools import partial from ldm.invoke.devices import choose_torch_device +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like @@ -13,6 +14,18 @@ class PLMSSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps, device) + def prepare_to_sample(self, t_enc, **kwargs): + super().prepare_to_sample(t_enc, **kwargs) + + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) + + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + else: + self.invokeai_diffuser.remove_cross_attention_control() + + # this is the essential routine @torch.no_grad() def p_sample( @@ -32,6 +45,7 @@ class PLMSSampler(Sampler): unconditional_conditioning=None, old_eps=[], t_next=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = *x.shape, x.device @@ -41,18 +55,15 @@ class PLMSSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): + # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = 
torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model( - x_in, t_in, c_in - ).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond - ) - + # step_index counts in the opposite direction to index + step_index = step_count-(index+1) + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index) if score_corrector is not None: assert self.model.parameterization == 'eps' e_t = score_corrector.modify_score( diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ff705513f8..7b09515bdc 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -2,13 +2,13 @@ ldm.models.diffusion.sampler Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc - ''' import torch import numpy as np from tqdm import tqdm from functools import partial from ldm.invoke.devices import choose_torch_device +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.modules.diffusionmodules.util import ( make_ddim_sampling_parameters, @@ -24,6 +24,8 @@ class Sampler(object): self.ddpm_num_timesteps = steps self.schedule = schedule self.device = device or choose_torch_device() + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, + model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: @@ -158,6 +160,18 @@ class Sampler(object): **kwargs, ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + # check to see if make_schedule() has run, and if not, run it if self.ddim_timesteps is None: self.make_schedule( @@ -190,10 +204,11 @@ class Sampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, steps=S, + **kwargs ) return samples, intermediates - #torch.no_grad() + @torch.no_grad() def do_sampling( self, cond, @@ -214,6 +229,7 @@ class Sampler(object): unconditional_guidance_scale=1.0, unconditional_conditioning=None, steps=None, + **kwargs ): b = shape[0] time_range = ( @@ -231,7 +247,7 @@ class Sampler(object): dynamic_ncols=True, ) old_eps = [] - self.prepare_to_sample(t_enc=total_steps) + self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs) img = self.get_initial_image(x_T,shape,total_steps) # probably don't need this at all @@ -274,6 +290,7 @@ class Sampler(object): unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, t_next=ts_next, + step_count=steps ) img, pred_x0, e_t = outs @@ -305,8 +322,9 @@ class Sampler(object): use_original_steps=False, init_latent = None, mask = None, + all_timesteps_count = None, + **kwargs ): - timesteps = ( np.arange(self.ddpm_num_timesteps) if use_original_steps @@ -321,7 +339,7 @@ class Sampler(object): iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent x0 = init_latent - self.prepare_to_sample(t_enc=total_steps) + self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, 
**kwargs) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -353,6 +371,7 @@ class Sampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, t_next = ts_next, + step_count=len(self.ddim_timesteps) ) x_dec, pred_x0, e_t = outs @@ -411,3 +430,21 @@ class Sampler(object): return self.model.inner_model.q_sample(x0,ts) ''' return self.model.q_sample(x0,ts) + + def conditioning_key(self)->str: + return self.model.model.conditioning_key + + def uses_inpainting_model(self)->bool: + return self.conditioning_key() in ('hybrid','concat') + + def adjust_settings(self,**kwargs): + ''' + This is a catch-all method for adjusting any instance variables + after the sampler is instantiated. No type-checking performed + here, so use with care! + ''' + for k in kwargs.keys(): + try: + setattr(self,k,kwargs[k]) + except AttributeError: + print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}') diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py new file mode 100644 index 0000000000..90c8ebb981 --- /dev/null +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -0,0 +1,216 @@ +from math import ceil +from typing import Callable, Optional, Union + +import torch + +from ldm.models.diffusion.cross_attention_control import CrossAttentionControl + + +class InvokeAIDiffuserComponent: + ''' + The aim of this component is to provide a single place for code that can be applied identically to + all InvokeAI diffusion procedures. + + At the moment it includes the following features: + * Cross attention control ("prompt2prompt") + * Hybrid conditioning (used for inpainting) + ''' + + + class ExtraConditioningInfo: + def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]): + self.cross_attention_control_args = cross_attention_control_args + + @property + def wants_cross_attention_control(self): + return self.cross_attention_control_args is not None + + def __init__(self, model, model_forward_callback: + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + ): + """ + :param model: the unet model to pass through to cross attention control + :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. 
most likely, this should simply call model.forward(x, sigma, conditioning) + """ + self.model = model + self.model_forward_callback = model_forward_callback + + + def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): + self.conditioning = conditioning + self.cross_attention_control_context = CrossAttentionControl.Context( + arguments=self.conditioning.cross_attention_control_args, + step_count=step_count + ) + CrossAttentionControl.setup_cross_attention_control(self.model, + cross_attention_control_args=self.conditioning.cross_attention_control_args + ) + #todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct + #todo: apply edit_options using step_count + + def remove_cross_attention_control(self): + self.conditioning = None + self.cross_attention_control_context = None + CrossAttentionControl.remove_cross_attention_control(self.model) + + + def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, + unconditioning: Union[torch.Tensor,dict], + conditioning: Union[torch.Tensor,dict], + unconditional_guidance_scale: float, + step_index: Optional[int]=None + ): + """ + :param x: current latents + :param sigma: aka t, passed to the internal model to control how much denoising will occur + :param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768] + :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768] + :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has + :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas . + :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. + """ + + CrossAttentionControl.clear_requests(self.model) + + cross_attention_control_types_to_do = [] + if self.cross_attention_control_context is not None: + percent_through = self.estimate_percent_through(step_index, sigma) + cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) + + wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) + wants_hybrid_conditioning = isinstance(conditioning, dict) + + if wants_hybrid_conditioning: + unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning) + elif wants_cross_attention_control: + unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + else: + unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning) + + # to scale how much effect conditioning has, calculate the changes it does and then scale that + scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale + combined_next_x = unconditioned_next_x + scaled_delta + + return combined_next_x + + + # methods below are called from do_diffusion_step and should be considered private to this class. 
+ + def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): + # fast batched path + x_twice = torch.cat([x] * 2) + sigma_twice = torch.cat([sigma] * 2) + both_conditionings = torch.cat([unconditioning, conditioning]) + unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, + both_conditionings).chunk(2) + return unconditioned_next_x, conditioned_next_x + + + def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): + assert isinstance(conditioning, dict) + assert isinstance(unconditioning, dict) + x_twice = torch.cat([x] * 2) + sigma_twice = torch.cat([sigma] * 2) + both_conditionings = dict() + for k in conditioning: + if isinstance(conditioning[k], list): + both_conditionings[k] = [ + torch.cat([unconditioning[k][i], conditioning[k][i]]) + for i in range(len(conditioning[k])) + ] + else: + both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) + unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) + return unconditioned_next_x, conditioned_next_x + + + def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): + # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) + # slower non-batched path (20% slower on mac MPS) + # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of + # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. + # This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8) + # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16, + # representing batched uncond + cond, but then when it comes to applying the saved attention, the + # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) + # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. 
+ unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) + + # process x using the original prompt, saving the attention maps + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_save_attention_maps(self.model, type) + _ = self.model_forward_callback(x, sigma, conditioning) + CrossAttentionControl.clear_requests(self.model) + + # process x again, using the saved attention maps to control where self.edited_conditioning will be applied + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) + edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning + conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) + + CrossAttentionControl.clear_requests(self.model) + + return unconditioned_next_x, conditioned_next_x + + def estimate_percent_through(self, step_index, sigma): + if step_index is not None and self.cross_attention_control_context is not None: + # percent_through will never reach 1.0 (but this is intended) + return float(step_index) / float(self.cross_attention_control_context.step_count) + # find the best possible index of the current sigma in the sigma sequence + sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1] + # flip because sigmas[0] is for the fully denoised image + # percent_through must be <1 + return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0]) + # print('estimated percent_through', percent_through, 'from sigma', sigma.item()) + + + # todo: make this work + @classmethod + def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) # aka sigmas + + deltas = None + uncond_latents = None + weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)] + + # below is fugly omg + num_actual_conditionings = len(c_or_weighted_c_list) + conditionings = [uc] + [c for c,weight in weighted_cond_list] + weights = [1] + [weight for c,weight in weighted_cond_list] + chunk_count = ceil(len(conditionings)/2) + deltas = None + for chunk_index in range(chunk_count): + offset = chunk_index*2 + chunk_size = min(2, len(conditionings)-offset) + + if chunk_size == 1: + c_in = conditionings[offset] + latents_a = forward_func(x_in[:-1], t_in[:-1], c_in) + latents_b = None + else: + c_in = torch.cat(conditionings[offset:offset+2]) + latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2) + + # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining + if chunk_index == 0: + uncond_latents = latents_a + deltas = latents_b - uncond_latents + else: + deltas = torch.cat((deltas, latents_a - uncond_latents)) + if latents_b is not None: + deltas = torch.cat((deltas, latents_b - uncond_latents)) + + # merge the weighted deltas together into a single merged delta + per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device) + normalize = False + if normalize: + per_delta_weights /= torch.sum(per_delta_weights) + reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1)) + deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True) + + # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale) + # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale)))) + + return uncond_latents + deltas_merged * 
global_guidance_scale + diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index ef9c2d3e65..4c36fa8a6c 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,5 +1,7 @@ from inspect import isfunction import math +from typing import Callable + import torch import torch.nn.functional as F from torch import nn, einsum @@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module): return x+h_ + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() @@ -170,46 +173,71 @@ class CrossAttention(nn.Module): self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - def einsum_op_compvis(self, q, k, v): - s = einsum('b i d, b j d -> b i j', q, k) - s = s.softmax(dim=-1, dtype=s.dtype) - return einsum('b i j, b j d -> b i d', s, v) + self.attention_slice_wrangler = None - def einsum_op_slice_0(self, q, k, v, slice_size): + def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): + ''' + Set custom attention calculator to be called when attention is calculated + :param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size), + which returns either the suggested_attention_slice or an adjusted equivalent. + self is the current CrossAttention module for which the callback is being invoked. + attention_scores are the scores for attention + suggested_attention_slice is a softmax(dim=-1) over attention_scores + dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. + If dim is >= 0, offset and slice_size specify the slice start and length. + + Pass None to use the default attention calculation. + :return: + ''' + self.attention_slice_wrangler = wrangler + + def einsum_lowest_level(self, q, k, v, dim, offset, slice_size): + # calculate attention scores + attention_scores = einsum('b i d, b j d -> b i j', q, k) + # calculate attenion slice by taking the best scores for each latent pixel + default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) + if self.attention_slice_wrangler is not None: + attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size) + else: + attention_slice = default_attention_slice + + return einsum('b i j, b j d -> b i d', attention_slice, v) + + def einsum_op_slice_dim0(self, q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) return r - def einsum_op_slice_1(self, q, k, v, slice_size): + def einsum_op_slice_dim1(self, q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v) + r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) return r def einsum_op_mps_v1(self, q, k, v): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_op_compvis(q, k, v) + return self.einsum_lowest_level(q, k, v, None, None, None) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_1(q, k, 
v, slice_size) + return self.einsum_op_slice_dim1(q, k, v, slice_size) def einsum_op_mps_v2(self, q, k, v): if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_op_compvis(q, k, v) + return self.einsum_lowest_level(q, k, v, None, None, None) else: - return self.einsum_op_slice_0(q, k, v, 1) + return self.einsum_op_slice_dim0(q, k, v, 1) def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: - return self.einsum_op_compvis(q, k, v) + return self.einsum_lowest_level(q, k, v, None, None, None) div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() if div <= q.shape[0]: - return self.einsum_op_slice_0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) + return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) def einsum_op_cuda(self, q, k, v): stats = torch.cuda.memory_stats(q.device) @@ -221,7 +249,7 @@ class CrossAttention(nn.Module): # Divide factor of safety as there's copying and fragmentation return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - def einsum_op(self, q, k, v): + def get_attention_mem_efficient(self, q, k, v): if q.device.type == 'cuda': return self.einsum_op_cuda(q, k, v) @@ -244,8 +272,13 @@ class CrossAttention(nn.Module): del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = self.einsum_op(q, k, v) - return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + + r = self.get_attention_mem_efficient(q, k, v) + + hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) + return self.to_out(hidden_states) + + class BasicTransformerBlock(nn.Module): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 426fccced3..cf8644a7fb 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn from functools import partial @@ -454,6 +456,223 @@ class FrozenCLIPEmbedder(AbstractEncoder): def encode(self, text, **kwargs): return self(text, **kwargs) +class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): + + fragment_weights_key = "fragment_weights" + return_tokens_key = "return_tokens" + + def forward(self, text: list, **kwargs): + ''' + + :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different + weights shall be applied. + :param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights + for the prompt fragments. In this case text must contain batches of lists of prompt fragments. 
+ :return: A tensor of shape (B, 77, 768) containing weighted embeddings + ''' + if self.fragment_weights_key not in kwargs: + # fallback to base class implementation + return super().forward(text, **kwargs) + + fragment_weights = kwargs[self.fragment_weights_key] + # self.transformer doesn't like receiving "fragment_weights" as an argument + kwargs.pop(self.fragment_weights_key) + + should_return_tokens = False + if self.return_tokens_key in kwargs: + should_return_tokens = kwargs.get(self.return_tokens_key, False) + # self.transformer doesn't like having extra kwargs + kwargs.pop(self.return_tokens_key) + + batch_z = None + batch_tokens = None + for fragments, weights in zip(text, fragment_weights): + + # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively + # applying a multiplier to the CFG scale on a per-token basis). + # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept + # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active + # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to + # "red" is to tell SD that it should almost completely *ignore* redness). + # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt + # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the + # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. + + # handle weights >=1 + tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights) + base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + # this is our starting point + embeddings = base_embedding.unsqueeze(0) + per_embedding_weights = [1.0] + + # now handle weights <1 + # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped + # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting + # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words + # removed. + # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding + # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it + # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". 
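+            # Worked example of the loop below (illustrative numbers): for "mountain:1 man:0.5" an extra
+            # embedding is built for just "mountain" and given lerp weight tan((1 - 0.5) * pi/2) = 1.0;
+            # apply_embedding_weights() then normalizes [1.0, 1.0] to [0.5, 0.5], landing exactly half-way
+            # between the embeddings for "mountain man" and "mountain", as described above.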
+ for index, fragment_weight in enumerate(weights): + if fragment_weight < 1: + fragments_without_this = fragments[:index] + fragments[index+1:] + weights_without_this = weights[:index] + weights[index+1:] + tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this) + embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1) + # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 + # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding + # therefore: + # fragment_weight = 1: we are at base_z => lerp weight 0 + # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 + # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf + # so let's use tan(), because: + # tan is 0.0 at 0, + # 1.0 at PI/4, and + # inf at PI/2 + # -> tan((1-weight)*PI/2) should give us ideal lerp weights + epsilon = 1e-9 + fragment_weight = max(epsilon, fragment_weight) # inf is bad + embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2) + # todo handle negative weight? + + per_embedding_weights.append(embedding_lerp_weight) + + lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) + + #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") + + # append to batch + batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1) + batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1) + + # should have shape (B, 77, 768) + #print(f"assembled all tokens into tensor of shape {batch_z.shape}") + + if should_return_tokens: + return batch_z, batch_tokens + else: + return batch_z + + def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]: + tokens = self.tokenizer( + fragments, + truncation=True, + max_length=self.max_length, + return_overflowing_tokens=False, + padding='do_not_pad', + return_tensors=None, # just give me a list of ints + )['input_ids'] + if include_start_and_end_markers: + return tokens + else: + return [x[1:-1] for x in tokens] + + + @classmethod + def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: + per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) + if normalize: + per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) + reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) + #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape) + return torch.sum(embeddings * reshaped_weights, dim=1) + # lerped embeddings has shape (77, 768) + + + def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor): + ''' + + :param fragments: + :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine. 
+        :return:
+        '''
+        # empty is meaningful
+        if len(fragments) == 0 and len(weights) == 0:
+            fragments = ['']
+            weights = [1]
+        item_encodings = self.tokenizer(
+            fragments,
+            truncation=True,
+            max_length=self.max_length,
+            return_overflowing_tokens=True,
+            padding='do_not_pad',
+            return_tensors=None, # just give me a list of ints
+        )['input_ids']
+        all_tokens = []
+        per_token_weights = []
+        #print("all fragments:", fragments, weights)
+        for index, fragment in enumerate(item_encodings):
+            weight = weights[index]
+            #print("processing fragment", fragment, weight)
+            fragment_tokens = item_encodings[index]
+            #print("fragment", fragment, "processed to", fragment_tokens)
+            # trim bos and eos markers before appending
+            all_tokens.extend(fragment_tokens[1:-1])
+            per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
+
+        if (len(all_tokens) + 2) > self.max_length:
+            excess_token_count = (len(all_tokens) + 2) - self.max_length
+            print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
+            all_tokens = all_tokens[:self.max_length - 2]
+            per_token_weights = per_token_weights[:self.max_length - 2]
+
+        # pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
+        # (77 = self.max_length)
+        pad_length = self.max_length - 1 - len(all_tokens)
+        all_tokens.insert(0, self.tokenizer.bos_token_id)
+        all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
+        per_token_weights.insert(0, 1)
+        per_token_weights.extend([1] * pad_length)
+
+        all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
+        #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
+        return all_tokens_tensor, per_token_weights_tensor
+
+    def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
+        '''
+        Build a tensor representing the passed-in tokens, each of which has a weight.
+        :param tokens: A tensor of shape (77) containing token ids (integers)
+        :param per_token_weights: A tensor of shape (77) containing weights (floats)
+        :param weight_delta_from_empty: If True, scale each token's distance from an "empty" prompt embedding; if False, scale the whole feature vector for each token
+        :param kwargs: passed on to self.transformer()
+        :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
+        '''
+        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
+        z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
+        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
+
+        if weight_delta_from_empty:
+            empty_tokens = self.tokenizer([''] * z.shape[0],
+                                          truncation=True,
+                                          max_length=self.max_length,
+                                          padding='max_length',
+                                          return_tensors='pt'
+                                          )['input_ids'].to(self.device)
+            empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
+            z_delta_from_empty = z - empty_z
+            weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
+
+            weighted_z_delta_from_empty = (weighted_z-empty_z)
+            #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
+
+            #print("using empty-delta method, first 5 rows:")
+            #print(weighted_z[:5])
+
+            return weighted_z
+
+        else:
+            original_mean = z.mean()
+            z *= batch_weights_expanded
+            after_weighting_mean = z.mean()
+            # correct the mean.
not sure if this is right but it's what the automatic1111 fork of SD does + mean_correction_factor = original_mean/after_weighting_mean + z *= mean_correction_factor + return z + class FrozenCLIPTextEmbedder(nn.Module): """ diff --git a/scripts/invoke.py b/scripts/invoke.py index faa85de80e..5094473d33 100644 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -18,6 +18,7 @@ from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log from omegaconf import OmegaConf from pathlib import Path +from pyparsing import ParseException # global used in multiple functions (fix) infile = None @@ -172,8 +173,7 @@ def main_loop(gen, opt): pass if len(opt.prompt) == 0: - print('\nTry again with a prompt!') - continue + opt.prompt = '' # width and height are set by model if not specified if not opt.width: @@ -328,12 +328,16 @@ def main_loop(gen, opt): if operation == 'generate': catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts opt.last_operation='generate' - gen.prompt2image( - image_callback=image_writer, - step_callback=step_callback, - catch_interrupts=catch_ctrl_c, - **vars(opt) - ) + try: + gen.prompt2image( + image_callback=image_writer, + step_callback=step_callback, + catch_interrupts=catch_ctrl_c, + **vars(opt) + ) + except ParseException as e: + print('** An error occurred while processing your prompt **') + print(f'** {str(e)} **') elif operation == 'postprocess': print(f'>> fixing {opt.prompt}') opt.last_operation = do_postprocess(gen,opt,image_writer) @@ -528,12 +532,8 @@ def del_config(model_name:str, gen, opt, completer): if model_name == current_model: print("** Can't delete active model. !switch to another model first. **") return - yaml_str = gen.model_cache.del_model(model_name) - - tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp') - with open(tmpfile, 'w') as outfile: - outfile.write(yaml_str) - os.rename(tmpfile,opt.conf) + if gen.model_cache.del_model(model_name): + gen.model_cache.commit(opt.conf) print(f'** {model_name} deleted') completer.del_model(model_name) @@ -592,7 +592,9 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak def do_textmask(gen, opt, callback): image_path = opt.prompt - assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **' + if not os.path.exists(image_path): + image_path = os.path.join(opt.outdir,image_path) + assert os.path.exists(image_path), '** "{opt.prompt}" not found. 
Please enter the name of an existing image file to mask **' assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **' tm = opt.text_mask[0] threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5 diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py new file mode 100644 index 0000000000..03db9274d4 --- /dev/null +++ b/tests/test_prompt_parser.py @@ -0,0 +1,440 @@ +import unittest + +import pyparsing + +from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \ + Fragment + + +def parse_prompt(prompt_string): + pp = PromptParser() + #print(f"parsing '{prompt_string}'") + parse_result = pp.parse_conjunction(prompt_string) + #print(f"-> parsed '{prompt_string}' to {parse_result}") + return parse_result + +def make_basic_conjunction(strings: list[str]): + fragments = [Fragment(x) for x in strings] + return Conjunction([FlattenedPrompt(fragments)]) + +def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]): + fragments = [Fragment(x, w) for x,w in weighted_strings] + return Conjunction([FlattenedPrompt(fragments)]) + + +class PromptParserTestCase(unittest.TestCase): + + def test_empty(self): + self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt('')) + + def test_basic(self): + self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)")) + self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) + self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating")) + self.assertEqual(make_basic_conjunction(['Dalí']), parse_prompt("Dalí")) + + def test_attention(self): + self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5")) + self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-")) + self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5")) + self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++")) + self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++")) + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]), + parse_prompt("(pretty flowers)+")) + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]), + parse_prompt("(pretty flowers)+, the flames are too hot")) + + def test_no_parens_attention_runon(self): + 
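+        # A trailing run of + or - with no parentheses binds only to the single preceding word:
+        # "flames++" scales just "flames" by 1.1**2 and leaves the surrounding words at weight 1.0.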
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++")) + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames")) + + + def test_explicit_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) + self.assertEqual( + Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), + FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) + + def test_conjunction_weights(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)')) + + with self.assertRaises(PromptParser.ParsingException): + parse_prompt('("fire", "flames").and(2)') + parse_prompt('("fire", "flames").and(2,1,2)') + + def test_complex_conjunction(self): + + #print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++")) + + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)")) + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), + FlattenedPrompt([("a person with a hat", 1.0), + ("riding a", 1.1*1.1), + CrossAttentionControlSubstitute( + [Fragment("bicycle", pow(1.1,2))], + [Fragment("skateboard", pow(1.1,2))]) + ]) + ], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)")) + + def test_badly_formed(self): + def make_untouched_prompt(prompt): + return Conjunction([FlattenedPrompt([(prompt, 1.0)])]) + + def assert_if_prompt_string_not_untouched(prompt): + self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) + + assert_if_prompt_string_not_untouched('a test prompt') + assert_if_prompt_string_not_untouched('a badly formed +test prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test prompt') + #with self.assertRaises(pyparsing.ParseException): + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed +test prompt') + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt')) + with 
self.assertRaises(pyparsing.ParseException): + parse_prompt('(((a badly (formed +test )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed +test prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed +test +prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('("((a badly (formed +test ").blend(1.0)') + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]), + parse_prompt("hamburger ((bun))")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]), + parse_prompt("hamburger (bun)")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]), + parse_prompt("hamburger (kaiser roll)")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]), + parse_prompt("hamburger ((kaiser roll))")) + + + def test_blend(self): + self.assertEqual(Conjunction( + [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), + parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") + ) + self.assertEqual(Conjunction([Blend( + [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])], + [0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), + FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), + FlattenedPrompt([('hi', 1.0)])], + weights=[0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + # blend a single entry is not a failure + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), + parse_prompt("(\"fire\").blend(0.7)") + ) + # blend with empty + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \"\").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") + ) + + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]), + FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]), + parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)') + ) + + + def test_nested(self): + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]), + parse_prompt('fire (flames (trees)1.5)2.0')) + self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), + FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], + weights=[1.0, 1.0])]), + parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)')) + + def test_cross_attention_control(self): + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating 
a hotdog")) + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog")) + + + fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])]) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")')) + + fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])]) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")')) + + trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])]) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)')) + + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]), + (', fire', 1.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire ')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire ')) + + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), + parse_prompt('a forest landscape "".swap("in winter")')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), + parse_prompt('a forest landscape " ".swap("in winter")')) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + 
parse_prompt('a forest landscape "in winter".swap("")')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap()')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap(" ")')) + + def test_cross_attention_control_with_attention(self): + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0')) + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0')) + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0')) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)")) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)")) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)")) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(0.9,1))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)")) + + + def test_cross_attention_control_options(self): + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 
1)], options={'s_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog")) + + self.assertEqual( + Conjunction([ + FlattenedPrompt([Fragment('a fantasy forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)], + options={'s_start': 0.8, 't_start': 0.8})])]), + parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)")) + + + def test_escaping(self): + + # make sure ", ( and ) can be escaped + + self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)')) + self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)')) + self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+')) + # same weights for each are combined into one + self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-')) + + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1')) + # same weights for each are combined into one + self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9')) + + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+')) + 
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+')) + + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+')) + + self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man 
w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry ')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( ')) + + def test_cross_attention_escaping(self): + + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (man).swap(monkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (man).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (m\(an).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), + parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) + + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain ("man").swap(monkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain ("man").swap("monkey")')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (\\"man).swap("monkey")')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (man).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (m\(an).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), + parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) + + def test_legacy_blend(self): + pp = PromptParser() + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man mountain', 1)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain man:1 man mountain:1')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), + FlattenedPrompt([('man', 1), ('mountain', 0.9)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain+ man:1 man mountain-:1')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), + FlattenedPrompt([('man', 1), ('mountain', 0.9)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain+ man:1 man mountain-')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), + FlattenedPrompt([('man', 1), ('mountain', 0.9)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain+ man: man mountain-:')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man mountain', 1)])], + weights=[0.75,0.25]), + pp.parse_legacy_blend('mountain man:3 man mountain:1')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man 
mountain', 1)])], + weights=[1.0,0.0]), + pp.parse_legacy_blend('mountain man:3 man mountain:0')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man mountain', 1)])], + weights=[0.8,0.2]), + pp.parse_legacy_blend('"mountain man":4 man mountain')) + + + def test_single(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/validate_pr_prompt.txt b/tests/validate_pr_prompt.txt new file mode 100644 index 0000000000..c1a8e00cbe --- /dev/null +++ b/tests/validate_pr_prompt.txt @@ -0,0 +1 @@ +banana sushi -Ak_lms -S42 -s10
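
A minimal usage sketch, for orientation only, of the prompt-parsing API exercised by tests/test_prompt_parser.py above. The import path, the PromptParser class, and its parse_conjunction/parse_legacy_blend methods come from this diff; the example prompt strings are invented for illustration and are not part of the patch:

    from ldm.invoke.prompt_parser import PromptParser

    pp = PromptParser()

    # each trailing "+" multiplies a fragment's weight by 1.1; .swap() requests a
    # cross-attention-controlled substitution of one fragment for another
    conjunction = pp.parse_conjunction('a fantasy landscape (castle)++ "dragon".swap(wyvern)')
    print(conjunction)

    # legacy colon-weighted blend syntax is handled by a separate entry point
    blend = pp.parse_legacy_blend('mountain man:3 man mountain:1')
    print(blend)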