Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 8c073a7818: Merge branch 'main' into patch-1

.github/workflows/test-invoke-conda.yml (vendored, 161 changed lines)
@ -1,161 +0,0 @@
|
||||
name: Test invoke.py
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
branches:
|
||||
- 'main'
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
- 'converted_to_draft'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
fail_if_pull_request_is_draft:
|
||||
if: github.event.pull_request.draft == true
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Fails in order to indicate that pull request needs to be marked as ready to review and unit tests workflow needs to pass.
|
||||
run: exit 1
|
||||
|
||||
matrix:
|
||||
if: github.event.pull_request.draft == false
|
||||
strategy:
|
||||
matrix:
|
||||
stable-diffusion-model:
|
||||
- 'stable-diffusion-1.5'
|
||||
environment-yaml:
|
||||
- environment-lin-amd.yml
|
||||
- environment-lin-cuda.yml
|
||||
- environment-mac.yml
|
||||
- environment-win-cuda.yml
|
||||
include:
|
||||
- environment-yaml: environment-lin-amd.yml
|
||||
os: ubuntu-22.04
|
||||
curl-command: curl
|
||||
github-env: $GITHUB_ENV
|
||||
default-shell: bash -l {0}
|
||||
- environment-yaml: environment-lin-cuda.yml
|
||||
os: ubuntu-22.04
|
||||
curl-command: curl
|
||||
github-env: $GITHUB_ENV
|
||||
default-shell: bash -l {0}
|
||||
- environment-yaml: environment-mac.yml
|
||||
os: macos-12
|
||||
curl-command: curl
|
||||
github-env: $GITHUB_ENV
|
||||
default-shell: bash -l {0}
|
||||
- environment-yaml: environment-win-cuda.yml
|
||||
os: windows-2022
|
||||
curl-command: curl.exe
|
||||
github-env: $env:GITHUB_ENV
|
||||
default-shell: pwsh
|
||||
- stable-diffusion-model: stable-diffusion-1.5
|
||||
stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
|
||||
stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1
|
||||
stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt
|
||||
name: ${{ matrix.environment-yaml }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
CONDA_ENV_NAME: invokeai
|
||||
INVOKEAI_ROOT: '${{ github.workspace }}/invokeai'
|
||||
defaults:
|
||||
run:
|
||||
shell: ${{ matrix.default-shell }}
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
id: checkout-sources
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: create models.yaml from example
|
||||
run: |
|
||||
mkdir -p ${{ env.INVOKEAI_ROOT }}/configs
|
||||
cp configs/models.yaml.example ${{ env.INVOKEAI_ROOT }}/configs/models.yaml
|
||||
|
||||
- name: create environment.yml
|
||||
run: cp "environments-and-requirements/${{ matrix.environment-yaml }}" environment.yml
|
||||
|
||||
- name: Use cached conda packages
|
||||
id: use-cached-conda-packages
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/conda_pkgs_dir
|
||||
key: conda-pkgs-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles(matrix.environment-yaml) }}
|
||||
|
||||
- name: Activate Conda Env
|
||||
id: activate-conda-env
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
with:
|
||||
activate-environment: ${{ env.CONDA_ENV_NAME }}
|
||||
environment-file: environment.yml
|
||||
miniconda-version: latest
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: set test prompt to development branch validation
|
||||
if: ${{ github.ref == 'refs/heads/development' }}
|
||||
run: echo "TEST_PROMPTS=tests/dev_prompts.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: set test prompt to Pull Request validation
|
||||
if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }}
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: Use Cached Stable Diffusion Model
|
||||
id: cache-sd-model
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-${{ matrix.stable-diffusion-model }}
|
||||
with:
|
||||
path: ${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}
|
||||
key: ${{ env.cache-name }}
|
||||
|
||||
- name: Download ${{ matrix.stable-diffusion-model }}
|
||||
id: download-stable-diffusion-model
|
||||
if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
mkdir -p "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}"
|
||||
${{ matrix.curl-command }} -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" -o "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}/${{ matrix.stable-diffusion-model-dl-name }}" -L ${{ matrix.stable-diffusion-model-url }}
|
||||
|
||||
- name: run configure_invokeai.py
|
||||
id: run-preload-models
|
||||
run: |
|
||||
python scripts/configure_invokeai.py --skip-sd-weights --yes
|
||||
|
||||
- name: cat invokeai.init
|
||||
id: cat-invokeai
|
||||
run: cat ${{ env.INVOKEAI_ROOT }}/invokeai.init
|
||||
|
||||
- name: Run the tests
|
||||
id: run-tests
|
||||
if: matrix.os != 'windows-2022'
|
||||
run: |
|
||||
time python scripts/invoke.py \
|
||||
--no-patchmatch \
|
||||
--no-nsfw_checker \
|
||||
--model ${{ matrix.stable-diffusion-model }} \
|
||||
--from_file ${{ env.TEST_PROMPTS }} \
|
||||
--root="${{ env.INVOKEAI_ROOT }}" \
|
||||
--outdir="${{ env.INVOKEAI_ROOT }}/outputs"
|
||||
|
||||
- name: export conda env
|
||||
id: export-conda-env
|
||||
if: matrix.os != 'windows-2022'
|
||||
run: |
|
||||
mkdir -p outputs/img-samples
|
||||
conda env export --name ${{ env.CONDA_ENV_NAME }} > ${{ env.INVOKEAI_ROOT }}/outputs/environment-${{ runner.os }}-${{ runner.arch }}.yml
|
||||
|
||||
- name: Archive results
|
||||
if: matrix.os != 'windows-2022'
|
||||
id: archive-results
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: results_${{ matrix.requirements-file }}_${{ matrix.python-version }}
|
||||
path: ${{ env.INVOKEAI_ROOT }}/outputs
|
.github/workflows/test-invoke-pip.yml (vendored, 81 changed lines)
@ -4,8 +4,6 @@ on:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
branches:
|
||||
- 'main'
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
@ -17,14 +15,14 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
fail_if_pull_request_is_draft:
|
||||
if: github.event.pull_request.draft == true
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Fails in order to indicate that pull request needs to be marked as ready to review and unit tests workflow needs to pass.
|
||||
run: exit 1
|
||||
# fail_if_pull_request_is_draft:
|
||||
# if: github.event.pull_request.draft == true && github.head_ref != 'dev/diffusers'
|
||||
# runs-on: ubuntu-18.04
|
||||
# steps:
|
||||
# - name: Fails in order to indicate that pull request needs to be marked as ready to review and unit tests workflow needs to pass.
|
||||
# run: exit 1
|
||||
matrix:
|
||||
if: github.event.pull_request.draft == false
|
||||
if: github.event.pull_request.draft == false || github.head_ref == 'dev/diffusers'
|
||||
strategy:
|
||||
matrix:
|
||||
stable-diffusion-model:
|
||||
@ -40,26 +38,23 @@ jobs:
|
||||
include:
|
||||
- requirements-file: requirements-lin-cuda.txt
|
||||
os: ubuntu-22.04
|
||||
curl-command: curl
|
||||
github-env: $GITHUB_ENV
|
||||
- requirements-file: requirements-lin-amd.txt
|
||||
os: ubuntu-22.04
|
||||
curl-command: curl
|
||||
github-env: $GITHUB_ENV
|
||||
- requirements-file: requirements-mac-mps-cpu.txt
|
||||
os: macOS-12
|
||||
curl-command: curl
|
||||
github-env: $GITHUB_ENV
|
||||
- requirements-file: requirements-win-colab-cuda.txt
|
||||
os: windows-2022
|
||||
curl-command: curl.exe
|
||||
github-env: $env:GITHUB_ENV
|
||||
- stable-diffusion-model: stable-diffusion-1.5
|
||||
stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
|
||||
stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1
|
||||
stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt
|
||||
name: ${{ matrix.requirements-file }} on ${{ matrix.python-version }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
INVOKE_MODEL_RECONFIGURE: '--yes'
|
||||
INVOKEAI_ROOT: '${{ github.workspace }}/invokeai'
|
||||
PYTHONUNBUFFERED: 1
|
||||
HAVE_SECRETS: ${{ secrets.HUGGINGFACE_TOKEN != '' }}
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
id: checkout-sources
|
||||
@ -77,10 +72,17 @@ jobs:
|
||||
echo "INVOKEAI_ROOT=${{ github.workspace }}/invokeai" >> ${{ matrix.github-env }}
|
||||
echo "INVOKEAI_OUTDIR=${{ github.workspace }}/invokeai/outputs" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: create models.yaml from example
|
||||
run: |
|
||||
mkdir -p ${{ env.INVOKEAI_ROOT }}/configs
|
||||
cp configs/models.yaml.example ${{ env.INVOKEAI_ROOT }}/configs/models.yaml
|
||||
- name: Use Cached diffusers-1.5
|
||||
id: cache-sd-model
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: huggingface-${{ matrix.stable-diffusion-model }}
|
||||
with:
|
||||
path: |
|
||||
${{ env.INVOKEAI_ROOT }}/models/runwayml
|
||||
${{ env.INVOKEAI_ROOT }}/models/stabilityai
|
||||
${{ env.INVOKEAI_ROOT }}/models/CompVis
|
||||
key: ${{ env.cache-name }}
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
@ -110,30 +112,31 @@ jobs:
|
||||
- name: install requirements
|
||||
run: pip3 install -r '${{ matrix.requirements-file }}'
|
||||
|
||||
- name: Use Cached Stable Diffusion Model
|
||||
id: cache-sd-model
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-${{ matrix.stable-diffusion-model }}
|
||||
with:
|
||||
path: ${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}
|
||||
key: ${{ env.cache-name }}
|
||||
|
||||
- name: Download ${{ matrix.stable-diffusion-model }}
|
||||
id: download-stable-diffusion-model
|
||||
if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
mkdir -p "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}"
|
||||
${{ matrix.curl-command }} -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" -o "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}/${{ matrix.stable-diffusion-model-dl-name }}" -L ${{ matrix.stable-diffusion-model-url }}
|
||||
|
||||
- name: run configure_invokeai.py
|
||||
id: run-preload-models
|
||||
run: python3 scripts/configure_invokeai.py --skip-sd-weights --yes
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
|
||||
run: >
|
||||
configure_invokeai.py
|
||||
--yes
|
||||
--full-precision # can't use fp16 weights without a GPU
|
||||
|
||||
- name: Run the tests
|
||||
id: run-tests
|
||||
if: matrix.os != 'windows-2022'
|
||||
run: python3 scripts/invoke.py --no-patchmatch --no-nsfw_checker --model ${{ matrix.stable-diffusion-model }} --from_file ${{ env.TEST_PROMPTS }} --root="${{ env.INVOKEAI_ROOT }}" --outdir="${{ env.INVOKEAI_OUTDIR }}"
|
||||
env:
|
||||
# Set offline mode to make sure configure preloaded successfully.
|
||||
HF_HUB_OFFLINE: 1
|
||||
HF_DATASETS_OFFLINE: 1
|
||||
TRANSFORMERS_OFFLINE: 1
|
||||
run: >
|
||||
python3 scripts/invoke.py
|
||||
--no-patchmatch
|
||||
--no-nsfw_checker
|
||||
--model ${{ matrix.stable-diffusion-model }}
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
--root="${{ env.INVOKEAI_ROOT }}"
|
||||
--outdir="${{ env.INVOKEAI_OUTDIR }}"
|
||||
|
||||
- name: Archive results
|
||||
id: archive-results
|
||||
|
README.md (96 changed lines)
@@ -8,12 +8,10 @@

[![latest release badge]][latest release link] [![github stars badge]][github stars link] [![github forks badge]][github forks link]

[![CI checks on main badge]][CI checks on main link] [![CI checks on dev badge]][CI checks on dev link] [![latest commit to dev badge]][latest commit to dev link]
[![CI checks on main badge]][CI checks on main link] [![latest commit to main badge]][latest commit to main link]

[![github open issues badge]][github open issues link] [![github open prs badge]][github open prs link]

[CI checks on dev badge]: https://flat.badgen.net/github/checks/invoke-ai/InvokeAI/development?label=CI%20status%20on%20dev&cache=900&icon=github
[CI checks on dev link]: https://github.com/invoke-ai/InvokeAI/actions?query=branch%3Adevelopment
[CI checks on main badge]: https://flat.badgen.net/github/checks/invoke-ai/InvokeAI/main?label=CI%20status%20on%20main&cache=900&icon=github
[CI checks on main link]: https://github.com/invoke-ai/InvokeAI/actions/workflows/test-invoke-conda.yml
[discord badge]: https://flat.badgen.net/discord/members/ZmtBAhwWhy?icon=discord

@@ -26,19 +24,13 @@

[github open prs link]: https://github.com/invoke-ai/InvokeAI/pulls?q=is%3Apr+is%3Aopen
[github stars badge]: https://flat.badgen.net/github/stars/invoke-ai/InvokeAI?icon=github
[github stars link]: https://github.com/invoke-ai/InvokeAI/stargazers
[latest commit to dev badge]: https://flat.badgen.net/github/last-commit/invoke-ai/InvokeAI/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900
[latest commit to dev link]: https://github.com/invoke-ai/InvokeAI/commits/development
[latest commit to main badge]: https://flat.badgen.net/github/last-commit/invoke-ai/InvokeAI/main?icon=github&color=yellow&label=last%20dev%20commit&cache=900
[latest commit to main link]: https://github.com/invoke-ai/InvokeAI/commits/main
[latest release badge]: https://flat.badgen.net/github/release/invoke-ai/InvokeAI/development?icon=github
[latest release link]: https://github.com/invoke-ai/InvokeAI/releases
</div>

This is a fork of
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion),
the open source text-to-image generator. It provides a streamlined
process with various new features and options to aid the image
generation process. It runs on Windows, macOS and Linux machines, with
GPU cards with as little as 4 GB of RAM. It provides both a polished
Web interface (see below), and an easy-to-use command-line interface.
InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.

**Quick links**: [[How to Install](#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]

@@ -46,6 +38,9 @@ _Note: InvokeAI is rapidly evolving. Please use the
[Issues](https://github.com/invoke-ai/InvokeAI/issues) tab to report bugs and make feature
requests. Be sure to use the provided templates. They will help us diagnose issues faster._



# Getting Started with InvokeAI

For full installation and upgrade instructions, please see:
@@ -58,10 +53,7 @@ For full installation and upgrade instructions, please see:
5. Wait a while, until it is done.
6. The folder where you ran the installer from will now be filled with lots of files. If you are on Windows, double-click on the `invoke.bat` file. On macOS, open a Terminal window, drag `invoke.sh` from the folder into the Terminal, and press return. On Linux, run `invoke.sh`
7. Press 2 to open the "browser-based UI", press enter/return, wait a minute or two for Stable Diffusion to start up, then open your browser and go to http://localhost:9090.
8. Type `banana sushi` in the box on the top left and click `Invoke`:

<div align="center"><img src="docs/assets/invoke-web-server-1.png" width=640></div>

8. Type `banana sushi` in the box on the top left and click `Invoke`


## Table of Contents
@@ -76,7 +68,7 @@ For full installation and upgrade instructions, please see:
8. [Support](#support)
9. [Further Reading](#further-reading)

### Installation
## Installation

This fork is supported across Linux, Windows and Macintosh. Linux
users can use either an Nvidia-based card (with CUDA support) or an
@@ -108,52 +100,42 @@ to render 512x512 images.

- At least 12 GB of free disk space for the machine learning model, Python, and all its dependencies.

**Note**
## Features

If you have a Nvidia 10xx series card (e.g. the 1080ti), please
run the dream script in full-precision mode as shown below.
Feature documentation can be reviewed by navigating to [the InvokeAI Documentation page](https://invoke-ai.github.io/InvokeAI/features/)

Similarly, specify full-precision mode on Apple M1 hardware.
### *Web Server & UI*
InvokeAI offers a locally hosted Web Server & React Frontend, with an industry leading user experience. The Web-based UI allows for simple and intuitive workflows, and is responsive for use on mobile devices and tablets accessing the web server.

Precision is auto configured based on the device. If however you encounter
errors like 'expected type Float but found Half' or 'not implemented for Half'
you can try starting `invoke.py` with the `--precision=float32` flag to your initialization command
### *Unified Canvas*
The Unified Canvas is a fully integrated canvas implementation with support for all core generation capabilities, in/outpainting, brush tools, and more. This creative tool unlocks the capability for artists to create with AI as a creative collaborator, and can be used to augment AI-generated imagery, sketches, photography, renders, and more.

```bash
(invokeai) ~/InvokeAI$ python scripts/invoke.py --precision=float32
```
Or by updating your InvokeAI configuration file with this argument.
### *Advanced Prompt Syntax*
InvokeAI's advanced prompt syntax allows for token weighting, cross-attention control, and prompt blending, allowing for fine-tuned tweaking of your invocations and exploration of the latent space.

### Features
### *Command Line Interface*
For users utilizing a terminal-based environment, or who want to take advantage of CLI features, InvokeAI offers an extensive and actively supported command-line interface that provides the full suite of generation functionality available in the tool.

#### Major Features
### Other features
- *Support for both ckpt and diffusers models*
- *SD 2.0, 2.1 support*
- *Noise Control & Thresholding*
- *Popular Sampler Support*
- *Upscaling & Face Restoration Tools*
- *Embedding Manager & Support*
- *Model Manager & Support*

- [Web Server](https://invoke-ai.github.io/InvokeAI/features/WEB/)
- [Interactive Command Line Interface](https://invoke-ai.github.io/InvokeAI/features/CLI/)
- [Image To Image](https://invoke-ai.github.io/InvokeAI/features/IMG2IMG/)
- [Inpainting Support](https://invoke-ai.github.io/InvokeAI/features/INPAINTING/)
- [Outpainting Support](https://invoke-ai.github.io/InvokeAI/features/OUTPAINTING/)
- [Upscaling, face-restoration and outpainting](https://invoke-ai.github.io/InvokeAI/features/POSTPROCESS/)
- [Reading Prompts From File](https://invoke-ai.github.io/InvokeAI/features/PROMPTS/#reading-prompts-from-a-file)
- [Prompt Blending](https://invoke-ai.github.io/InvokeAI/features/PROMPTS/#prompt-blending)
- [Thresholding and Perlin Noise Initialization Options](https://invoke-ai.github.io/InvokeAI/features/OTHER/#thresholding-and-perlin-noise-initialization-options)
- [Negative/Unconditioned Prompts](https://invoke-ai.github.io/InvokeAI/features/PROMPTS/#negative-and-unconditioned-prompts)
- [Variations](https://invoke-ai.github.io/InvokeAI/features/VARIATIONS/)
- [Personalizing Text-to-Image Generation](https://invoke-ai.github.io/InvokeAI/features/TEXTUAL_INVERSION/)
- [Simplified API for text to image generation](https://invoke-ai.github.io/InvokeAI/features/OTHER/#simplified-api)

#### Other Features

- [Google Colab](https://invoke-ai.github.io/InvokeAI/features/OTHER/#google-colab)
- [Seamless Tiling](https://invoke-ai.github.io/InvokeAI/features/OTHER/#seamless-tiling)
- [Shortcut: Reusing Seeds](https://invoke-ai.github.io/InvokeAI/features/OTHER/#shortcuts-reusing-seeds)
- [Preload Models](https://invoke-ai.github.io/InvokeAI/features/OTHER/#preload-models)
### Coming Soon
- *Node-Based Architecture & UI*
- And more...

### Latest Changes

For our latest changes, view our [Release Notes](https://github.com/invoke-ai/InvokeAI/releases)
For our latest changes, view our [Release
Notes](https://github.com/invoke-ai/InvokeAI/releases) and the
[CHANGELOG](docs/CHANGELOG.md).

### Troubleshooting
## Troubleshooting

Please check out our **[Q&A](https://invoke-ai.github.io/InvokeAI/help/TROUBLESHOOT/#faq)** to get solutions for common installation
problems and other issues.
@@ -183,13 +165,7 @@ their time, hard work and effort.

### Support

For support, please use this repository's GitHub Issues tracking service. Feel free to send me an
email if you use and like the script.
For support, please use this repository's GitHub Issues tracking service, or join the Discord.

Original portions of the software are Copyright (c) 2022
[Lincoln D. Stein](https://github.com/lstein)
Original portions of the software are Copyright (c) 2023 by respective contributors.

### Further Reading

Please see the original README for more information on this software and underlying algorithm,
located in the file [README-CompViz.md](https://invoke-ai.github.io/InvokeAI/other/README-CompViz/).
|
@ -1,35 +1,34 @@
|
||||
import eventlet
|
||||
import base64
|
||||
import glob
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import mimetypes
|
||||
import traceback
|
||||
import math
|
||||
import io
|
||||
import base64
|
||||
import os
|
||||
import json
|
||||
from threading import Event
|
||||
from uuid import uuid4
|
||||
|
||||
from werkzeug.utils import secure_filename
|
||||
import eventlet
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
from flask import Flask, redirect, send_from_directory, request, make_response
|
||||
from flask_socketio import SocketIO
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.Image import Image as ImageType
|
||||
from uuid import uuid4
|
||||
from threading import Event
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
|
||||
from ldm.invoke.generator.inpaint import infill_methods
|
||||
|
||||
from backend.modules.parameters import parameters_to_command
|
||||
from backend.modules.get_canvas_generation_mode import (
|
||||
get_canvas_generation_mode,
|
||||
)
|
||||
from backend.modules.parameters import parameters_to_command
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
|
||||
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
|
||||
from ldm.invoke.generator.inpaint import infill_methods
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
|
||||
|
||||
# Loading Arguments
|
||||
opt = Args()
|
||||
@ -304,7 +303,7 @@ class InvokeAIWebServer:
|
||||
def handle_request_capabilities():
|
||||
print(f">> System config requested")
|
||||
config = self.get_system_config()
|
||||
config["model_list"] = self.generate.model_cache.list_models()
|
||||
config["model_list"] = self.generate.model_manager.list_models()
|
||||
config["infill_methods"] = infill_methods()
|
||||
socketio.emit("systemConfig", config)
|
||||
|
||||
@ -317,11 +316,11 @@ class InvokeAIWebServer:
|
||||
{'search_folder': None, 'found_models': None},
|
||||
)
|
||||
else:
|
||||
search_folder, found_models = self.generate.model_cache.search_models(search_folder)
|
||||
search_folder, found_models = self.generate.model_manager.search_models(search_folder)
|
||||
socketio.emit(
|
||||
"foundModels",
|
||||
{'search_folder': search_folder, 'found_models': found_models},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
self.socketio.emit("error", {"message": (str(e))})
|
||||
print("\n")
|
||||
@ -335,18 +334,20 @@ class InvokeAIWebServer:
|
||||
model_name = new_model_config['name']
|
||||
del new_model_config['name']
|
||||
model_attributes = new_model_config
|
||||
if len(model_attributes['vae']) == 0:
|
||||
del model_attributes['vae']
|
||||
update = False
|
||||
current_model_list = self.generate.model_cache.list_models()
|
||||
current_model_list = self.generate.model_manager.list_models()
|
||||
if model_name in current_model_list:
|
||||
update = True
|
||||
|
||||
print(f">> Adding New Model: {model_name}")
|
||||
|
||||
self.generate.model_cache.add_model(
|
||||
self.generate.model_manager.add_model(
|
||||
model_name=model_name, model_attributes=model_attributes, clobber=True)
|
||||
self.generate.model_cache.commit(opt.conf)
|
||||
self.generate.model_manager.commit(opt.conf)
|
||||
|
||||
new_model_list = self.generate.model_cache.list_models()
|
||||
new_model_list = self.generate.model_manager.list_models()
|
||||
socketio.emit(
|
||||
"newModelAdded",
|
||||
{"new_model_name": model_name,
|
||||
@ -364,9 +365,9 @@ class InvokeAIWebServer:
|
||||
def handle_delete_model(model_name: str):
|
||||
try:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
self.generate.model_cache.del_model(model_name)
|
||||
self.generate.model_cache.commit(opt.conf)
|
||||
updated_model_list = self.generate.model_cache.list_models()
|
||||
self.generate.model_manager.del_model(model_name)
|
||||
self.generate.model_manager.commit(opt.conf)
|
||||
updated_model_list = self.generate.model_manager.list_models()
|
||||
socketio.emit(
|
||||
"modelDeleted",
|
||||
{"deleted_model_name": model_name,
|
||||
@ -385,7 +386,7 @@ class InvokeAIWebServer:
|
||||
try:
|
||||
print(f">> Model change requested: {model_name}")
|
||||
model = self.generate.set_model(model_name)
|
||||
model_list = self.generate.model_cache.list_models()
|
||||
model_list = self.generate.model_manager.list_models()
|
||||
if model is None:
|
||||
socketio.emit(
|
||||
"modelChangeFailed",
|
||||
@ -797,7 +798,7 @@ class InvokeAIWebServer:
|
||||
|
||||
# App Functions
|
||||
def get_system_config(self):
|
||||
model_list: dict = self.generate.model_cache.list_models()
|
||||
model_list: dict = self.generate.model_manager.list_models()
|
||||
active_model_name = None
|
||||
|
||||
for model_name, model_dict in model_list.items():
|
||||
@ -1205,9 +1206,16 @@ class InvokeAIWebServer:
|
||||
|
||||
print(generation_parameters)
|
||||
|
||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||
progress_state: PipelineIntermediateState = cb_args[0]
|
||||
return image_progress(progress_state.latents, progress_state.step)
|
||||
else:
|
||||
return image_progress(*cb_args, **kwargs)
|
||||
|
||||
self.generate.prompt2image(
|
||||
**generation_parameters,
|
||||
step_callback=image_progress,
|
||||
step_callback=diffusers_step_callback_adapter,
|
||||
image_callback=image_done
|
||||
)
|
||||
|
||||
|
@@ -12,6 +12,8 @@ SAMPLER_CHOICES = [
    "k_heun",
    "k_lms",
    "plms",
    # diffusers:
    "pndm",
]
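The hunk above extends `SAMPLER_CHOICES` with the diffusers `pndm` scheduler. A minimal sketch, not taken from this diff, of how such a list is typically used to validate a command-line sampler argument; the `--sampler` flag name here is illustrative only:

```python
import argparse

# Only the entries visible in the hunk above; the real list upstream is longer.
SAMPLER_CHOICES = [
    "k_heun",
    "k_lms",
    "plms",
    # diffusers:
    "pndm",
]

parser = argparse.ArgumentParser()
# choices= rejects unknown sampler names at parse time.
parser.add_argument("--sampler", choices=SAMPLER_CHOICES, default="k_lms")

args = parser.parse_args(["--sampler", "pndm"])
print(args.sampler)  # -> pndm
```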
@@ -2,9 +2,10 @@
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
--extra-index-url https://download.pytorch.org/whl/cu116
--trusted-host https://download.pytorch.org
accelerate~=0.14
accelerate~=0.15
albumentations
diffusers
diffusers[torch]~=0.11
einops
eventlet
flask_cors
flask_socketio
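This requirements hunk bumps `accelerate` to `~=0.15` and replaces the unpinned `diffusers` entry with `diffusers[torch]~=0.11`. A small sketch, not part of the diff and assuming the common `packaging` helper package is available, that checks installed versions against those pins when debugging a mismatched environment:

```python
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

# The pins named in the requirements hunk above.
PINS = {
    "accelerate": "~=0.15",
    "diffusers": "~=0.11",
}

for package, spec in PINS.items():
    installed = version(package)  # raises PackageNotFoundError if missing
    ok = installed in SpecifierSet(spec)
    print(f"{package} {installed} satisfies {spec}: {ok}")
```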
@ -1,72 +1,76 @@
|
||||
stable-diffusion-1.5:
|
||||
description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
|
||||
description: Stable Diffusion version 1.5 weight file (4.27 GB)
|
||||
repo_id: runwayml/stable-diffusion-v1-5
|
||||
config: v1-inference.yaml
|
||||
file: v1-5-pruned-emaonly.ckpt
|
||||
recommended: true
|
||||
width: 512
|
||||
height: 512
|
||||
format: diffusers
|
||||
recommended: True
|
||||
vae:
|
||||
repo_id: stabilityai/sd-vae-ft-mse
|
||||
default: True
|
||||
stable-diffusion-2.1:
|
||||
description: Stable Diffusion version 2.1 diffusers model (5.21 GB)
|
||||
repo_id: stabilityai/stable-diffusion-2-1
|
||||
format: diffusers
|
||||
recommended: True
|
||||
inpainting-1.5:
|
||||
description: RunwayML SD 1.5 model optimized for inpainting (4.27 GB)
|
||||
repo_id: runwayml/stable-diffusion-inpainting
|
||||
config: v1-inpainting-inference.yaml
|
||||
file: sd-v1-5-inpainting.ckpt
|
||||
recommended: True
|
||||
width: 512
|
||||
height: 512
|
||||
ft-mse-improved-autoencoder-840000:
|
||||
description: StabilityAI improved autoencoder fine-tuned for human faces (recommended; 335 MB)
|
||||
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||
config: VAE/default
|
||||
file: vae-ft-mse-840000-ema-pruned.ckpt
|
||||
format: ckpt
|
||||
vae:
|
||||
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||
file: vae-ft-mse-840000-ema-pruned.ckpt
|
||||
recommended: True
|
||||
width: 512
|
||||
height: 512
|
||||
stable-diffusion-1.4:
|
||||
description: The original Stable Diffusion version 1.4 weight file (4.27 GB)
|
||||
repo_id: CompVis/stable-diffusion-v-1-4-original
|
||||
config: v1-inference.yaml
|
||||
file: sd-v1-4.ckpt
|
||||
repo_id: CompVis/stable-diffusion-v1-4
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
format: diffusers
|
||||
vae:
|
||||
repo_id: stabilityai/sd-vae-ft-mse
|
||||
waifu-diffusion-1.4:
|
||||
description: Waifu diffusion 1.4
|
||||
format: diffusers
|
||||
repo_id: hakurei/waifu-diffusion
|
||||
waifu-diffusion-1.3:
|
||||
description: Stable Diffusion 1.4 fine tuned on anime-styled images (4.27 GB)
|
||||
repo_id: hakurei/waifu-diffusion-v1-3
|
||||
config: v1-inference.yaml
|
||||
file: model-epoch09-float32.ckpt
|
||||
format: ckpt
|
||||
vae:
|
||||
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||
file: vae-ft-mse-840000-ema-pruned.ckpt
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
trinart-2.0:
|
||||
description: An SD model finetuned with ~40,000 assorted high resolution manga/anime-style pictures (2.13 GB)
|
||||
repo_id: naclbit/trinart_stable_diffusion_v2
|
||||
config: v1-inference.yaml
|
||||
file: trinart2_step95000.ckpt
|
||||
format: diffusers
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
trinart_characters-1.0:
|
||||
description: An SD model finetuned with 19.2M anime/manga style images (2.13 GB)
|
||||
repo_id: naclbit/trinart_characters_19.2m_stable_diffusion_v1
|
||||
vae:
|
||||
repo_id: stabilityai/sd-vae-ft-mse
|
||||
trinart_characters-2.0:
|
||||
description: An SD model finetuned with 19.2M anime/manga style images (4.27 GB)
|
||||
repo_id: naclbit/trinart_derrida_characters_v2_stable_diffusion
|
||||
config: v1-inference.yaml
|
||||
file: trinart_characters_it4_v1.ckpt
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
trinart_vae:
|
||||
description: Custom autoencoder for trinart_characters
|
||||
repo_id: naclbit/trinart_characters_19.2m_stable_diffusion_v1
|
||||
config: VAE/trinart
|
||||
file: autoencoder_fix_kl-f8-trinart_characters.ckpt
|
||||
file: derrida_final.ckpt
|
||||
format: ckpt
|
||||
vae:
|
||||
repo_id: naclbit/trinart_derrida_characters_v2_stable_diffusion
|
||||
file: autoencoder_fix_kl-f8-trinart_characters.ckpt
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
papercut-1.0:
|
||||
description: SD 1.5 fine-tuned for papercut art (use "PaperCut" in your prompts) (2.13 GB)
|
||||
repo_id: Fictiverse/Stable_Diffusion_PaperCut_Model
|
||||
config: v1-inference.yaml
|
||||
file: PaperCut_v1.ckpt
|
||||
format: diffusers
|
||||
vae:
|
||||
repo_id: stabilityai/sd-vae-ft-mse
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
@ -75,6 +79,27 @@ voxel_art-1.0:
|
||||
repo_id: Fictiverse/Stable_Diffusion_VoxelArt_Model
|
||||
config: v1-inference.yaml
|
||||
file: VoxelArt_v1.ckpt
|
||||
format: ckpt
|
||||
vae:
|
||||
repo_id: stabilityai/sd-vae-ft-mse
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
ft-mse-improved-autoencoder-840000:
|
||||
description: StabilityAI improved autoencoder fine-tuned for human faces. Use with legacy .ckpt models ONLY (335 MB)
|
||||
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||
format: ckpt
|
||||
config: VAE/default
|
||||
file: vae-ft-mse-840000-ema-pruned.ckpt
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
trinart_vae:
|
||||
description: Custom autoencoder for trinart_characters for legacy .ckpt models only (335 MB)
|
||||
repo_id: naclbit/trinart_characters_19.2m_stable_diffusion_v1
|
||||
config: VAE/trinart
|
||||
format: ckpt
|
||||
file: autoencoder_fix_kl-f8-trinart_characters.ckpt
|
||||
recommended: False
|
||||
width: 512
|
||||
height: 512
|
||||
|
@@ -5,6 +5,25 @@
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
diffusers-1.4:
  description: 🤗🧨 Stable Diffusion v1.4
  format: diffusers
  repo_id: CompVis/stable-diffusion-v1-4
diffusers-1.5:
  description: 🤗🧨 Stable Diffusion v1.5
  format: diffusers
  repo_id: runwayml/stable-diffusion-v1-5
  default: true
diffusers-1.5+mse:
  description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE
  format: diffusers
  repo_id: runwayml/stable-diffusion-v1-5
  vae:
    repo_id: stabilityai/sd-vae-ft-mse
diffusers-inpainting-1.5:
  description: 🤗🧨 inpainting for Stable Diffusion v1.5
  format: diffusers
  repo_id: runwayml/stable-diffusion-inpainting
stable-diffusion-1.5:
  description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
  weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
@@ -12,7 +31,6 @@ stable-diffusion-1.5:
  width: 512
  height: 512
  vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
  default: true
stable-diffusion-1.4:
  description: Stable Diffusion inference model version 1.4
  config: configs/stable-diffusion/v1-inference.yaml
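In the hunk above, the new `diffusers-1.5` stanza gains `default: true` while the old ckpt stanza loses it. A short sketch, assuming a `models.yaml` shaped like this example and located at the path shown (both assumptions for illustration), that reports which stanza a loader would treat as the default:

```python
import yaml  # PyYAML


def default_model(path: str = "configs/models.yaml") -> str:
    """Return the stanza flagged `default: true`, else the first stanza."""
    with open(path, encoding="utf-8") as f:
        models = yaml.safe_load(f) or {}
    for name, attributes in models.items():
        if attributes.get("default"):
            return name
    return next(iter(models))  # fall back to the first entry


print(default_model())
```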
configs/stable-diffusion/v2-inference-v.yaml (new file, 68 lines)

@@ -0,0 +1,68 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
          lossconfig:
            target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
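The new `v2-inference-v.yaml` declares the `v`-prediction parameterization used by the 768-pixel SD 2.x checkpoints. A quick sketch, assuming `omegaconf` is installed (it appears in the project's requirements) and the file sits at the path shown in the header above, that loads the config and inspects that field:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("configs/stable-diffusion/v2-inference-v.yaml")

# "v" selects v-prediction; the plain v2-inference.yaml omits the key and
# therefore falls back to the default epsilon parameterization.
print(config.model.params.parameterization)                 # -> v
print(config.model.params.unet_config.params.context_dim)   # -> 1024
```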
@@ -4,6 +4,97 @@ title: Changelog

# :octicons-log-16: **Changelog**

## v2.3.0 <small>(15 January 2023)</small>

**Transition to diffusers**

Version 2.3 provides support for both the traditional `.ckpt` weight
checkpoint files as well as the HuggingFace `diffusers` format. This
introduces several changes you should know about.

1. The models.yaml format has been updated. There are now two
   different types of configuration stanza. The traditional ckpt
   one will look like this, with a `format` of `ckpt` and a
   `weights` field that points to the absolute or ROOTDIR-relative
   location of the ckpt file (a short parsing sketch follows this list).

   ```
   inpainting-1.5:
     description: RunwayML SD 1.5 model optimized for inpainting (4.27 GB)
     repo_id: runwayml/stable-diffusion-inpainting
     format: ckpt
     width: 512
     height: 512
     weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
     config: configs/stable-diffusion/v1-inpainting-inference.yaml
     vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
   ```

   A configuration stanza for a diffusers model hosted at HuggingFace will look like this,
   with a `format` of `diffusers` and a `repo_id` that points to the
   repository ID of the model on HuggingFace:

   ```
   stable-diffusion-2.1:
     description: Stable Diffusion version 2.1 diffusers model (5.21 GB)
     repo_id: stabilityai/stable-diffusion-2-1
     format: diffusers
   ```

   A configuration stanza for a diffusers model stored locally should
   look like this, with a `format` of `diffusers`, but a `path` field
   that points at the directory that contains `model_index.json`:

   ```
   waifu-diffusion:
     description: Latest waifu diffusion 1.4
     format: diffusers
     path: models/diffusers/hakurei-haifu-diffusion-1.4
   ```

2. The format of the models directory has changed to mimic the
   HuggingFace cache directory. By default, diffusers models are
   now automatically downloaded and retrieved from the directory
   `ROOTDIR/models/diffusers`, while other models are stored in
   the directory `ROOTDIR/models/hub`. This organization is the
   same as that used by HuggingFace for its cache management.

   This allows you to share diffusers and ckpt model files easily with
   other machine learning applications that use the HuggingFace
   libraries. To do this, set the environment variable HF_HOME
   before starting up InvokeAI to tell it what directory to
   cache models in. To tell InvokeAI to use the standard HuggingFace
   cache directory, you would set HF_HOME like this (Linux/Mac):

   `export HF_HOME=~/.cache/hugging_face`

3. If you upgrade to InvokeAI 2.3.* from an earlier version, there
   will be a one-time migration from the old models directory format
   to the new one. You will see a message about this the first time
   you start `invoke.py`.

4. Both the front and back ends of the model manager have been
   rewritten to accommodate diffusers. You can import models using
   their local file path, using their URLs, or their HuggingFace
   repo_ids. On the command line, all these syntaxes work:

   ```
   !import_model stabilityai/stable-diffusion-2-1-base
   !import_model /opt/sd-models/sd-1.4.ckpt
   !import_model https://huggingface.co/Fictiverse/Stable_Diffusion_PaperCut_Model/blob/main/PaperCut_v1.ckpt
   ```
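A minimal sketch (not InvokeAI's actual model manager) of how the stanza shapes from item 1 above can be told apart when reading `models.yaml`: dispatch on `format`, then on whether a diffusers stanza carries a `repo_id` or a local `path`. The config path and the `HF_HOME` default are assumptions for the example; nothing below contacts the Hub.

```python
import os

import yaml  # PyYAML

# Illustrates the HF_HOME advice in item 2; InvokeAI itself reads this
# variable at startup, so it must be set before launching invoke.py.
os.environ.setdefault("HF_HOME", os.path.expanduser("~/.cache/hugging_face"))


def describe_stanza(name: str, stanza: dict) -> str:
    fmt = stanza.get("format", "ckpt")
    if fmt == "ckpt":
        return f"{name}: legacy checkpoint at {stanza['weights']}"
    if fmt == "diffusers" and "path" in stanza:
        return f"{name}: local diffusers folder {stanza['path']}"
    if fmt == "diffusers":
        return f"{name}: diffusers repo {stanza['repo_id']} on the HuggingFace Hub"
    return f"{name}: unrecognized format {fmt!r}"


with open("configs/models.yaml", encoding="utf-8") as f:
    for model_name, model_stanza in (yaml.safe_load(f) or {}).items():
        print(describe_stanza(model_name, model_stanza))
```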

**KNOWN BUGS (15 January 2023)**

1. On CUDA systems, the 768 pixel stable-diffusion-2.0 and
   stable-diffusion-2.1 models can only be run as `diffusers` models
   when the `xformers` library is installed and configured. Without
   `xformers`, InvokeAI returns black images.

2. Inpainting and outpainting have regressed in quality.

Both these issues are being actively worked on.

## v2.2.4 <small>(11 December 2022)</small>

**the `invokeai` directory**
docs/assets/canvas_preview.png: new binary file (142 KiB), not shown
@@ -12,17 +12,18 @@ title: Installing Manually

## Introduction

You have two choices for manual installation, the [first
one](#PIP_method) uses basic Python virtual environment (`venv`)
commands and the PIP package manager. The [second one](#Conda_method)
based on the Anaconda3 package manager (`conda`). Both methods require
you to enter commands on the terminal, also known as the "console".
You have two choices for manual installation.
The [first one](#pip-Install) uses basic Python virtual environment (`venv`)
command and `pip` package manager.
The [second one](#Conda-method) uses Anaconda3 package manager (`conda`).
Both methods require you to enter commands on the terminal, also known as the
"console".

Note that the conda install method is currently deprecated and will not
be supported at some point in the future.
Note that the `conda` installation method is currently deprecated and will
not be supported at some point in the future.

On Windows systems you are encouraged to install and use the
[Powershell](https://learn.microsoft.com/en-us/powershell/scripting/install/installing-powershell-on-windows?view=powershell-7.3),
On Windows systems, you are encouraged to install and use the
[PowerShell](https://learn.microsoft.com/en-us/powershell/scripting/install/installing-powershell-on-windows?view=powershell-7.3),
which provides compatibility with Linux and Mac shells and nice
features such as command-line completion.

@@ -37,7 +38,7 @@ manager, please follow these steps:

   ```bash
   python -V
   ```


2. Clone the [InvokeAI](https://github.com/invoke-ai/InvokeAI) source code from
   GitHub:

@@ -52,15 +53,15 @@ manager, please follow these steps:
   environment named `invokeai`:

   ```bash
   python -mvenv invokeai
   python -m venv invokeai
   source invokeai/bin/activate
   ```

4. Make sure that pip is installed in your virtual environment and up to date:

   ```bash
   python -mensurepip --upgrade
   python -mpip install --upgrade pip
   python -m ensurepip --upgrade
   python -m pip install --upgrade pip
   ```

5. Pick the correct `requirements*.txt` file for your hardware and operating
@@ -199,20 +200,20 @@ manager, please follow these steps:

   You can permanently set the location of the runtime directory by setting the environment variable INVOKEAI_ROOT to the path of the directory. (A short sketch using this variable follows these steps.)

9. Render away!

   Browse the [features](../features/CLI.md) section to learn about all the things you
   can do with InvokeAI.

   Note that some GPUs are slow to warm up. In particular, when using an AMD
   card with the ROCm driver, you may have to wait for over a minute the first
   time you try to generate an image. Fortunately, after the warm up period
   time you try to generate an image. Fortunately, after the warm-up period
   rendering will be fast.

10. Subsequently, to relaunch the script, be sure to run "conda activate
    invokeai", enter the `InvokeAI` directory, and then launch the invoke
    script. If you forget to activate the 'invokeai' environment, the script
    will fail with multiple `ModuleNotFound` errors.
10. Subsequently, to relaunch the script, be sure to enter `InvokeAI` directory,
    activate the virtual environment, and then launch `invoke.py` script.
    If you forget to activate the virtual environment,
    the script will fail with multiple `ModuleNotFound` errors.
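A minimal illustration (not part of the install docs) of how a script can honor `INVOKEAI_ROOT`, falling back to a default location when the variable is unset; the fallback path here is an assumption made only for the sketch:

```python
import os
from pathlib import Path


def runtime_root() -> Path:
    """Resolve the InvokeAI runtime directory from INVOKEAI_ROOT, with a fallback."""
    configured = os.environ.get("INVOKEAI_ROOT")
    if configured:
        return Path(configured).expanduser().resolve()
    # Fallback used only for this sketch; the installer chooses the real default.
    return Path.home() / "invokeai"


print(f"Runtime directory: {runtime_root()}")
```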

!!! tip
@ -28,13 +28,18 @@ dependencies:
|
||||
- torch-fidelity=0.3.0
|
||||
- torchmetrics=0.7.0
|
||||
- torchvision
|
||||
- transformers=4.21.3
|
||||
- transformers~=4.25
|
||||
- pip:
|
||||
- accelerate
|
||||
- diffusers[torch]~=0.11
|
||||
- getpass_asterisk
|
||||
- huggingface-hub>=0.11.1
|
||||
- omegaconf==2.1.1
|
||||
- picklescan
|
||||
- pyreadline3
|
||||
- realesrgan
|
||||
- requests==2.25.1
|
||||
- safetensors
|
||||
- taming-transformers-rom1504
|
||||
- test-tube>=0.7.5
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
|
@ -9,14 +9,16 @@ dependencies:
|
||||
- numpy=1.23.3
|
||||
- pip:
|
||||
- --extra-index-url https://download.pytorch.org/whl/rocm5.2/
|
||||
- accelerate
|
||||
- albumentations==0.4.3
|
||||
- diffusers==0.6.0
|
||||
- diffusers[torch]~=0.11
|
||||
- einops==0.3.0
|
||||
- eventlet
|
||||
- flask==2.1.3
|
||||
- flask_cors==3.0.10
|
||||
- flask_socketio==5.3.0
|
||||
- getpass_asterisk
|
||||
- huggingface-hub>=0.11.1
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- imageio==2.9.0
|
||||
- kornia==0.6.0
|
||||
@ -28,6 +30,8 @@ dependencies:
|
||||
- pyreadline3
|
||||
- pytorch-lightning==1.7.7
|
||||
- realesrgan
|
||||
- requests==2.25.1
|
||||
- safetensors
|
||||
- send2trash==1.8.0
|
||||
- streamlit==1.12.0
|
||||
- taming-transformers-rom1504
|
||||
@ -38,7 +42,7 @@ dependencies:
|
||||
- torchaudio
|
||||
- torchmetrics==0.7.0
|
||||
- torchvision
|
||||
- transformers==4.21.3
|
||||
- transformers~=4.25
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
||||
|
@ -12,14 +12,16 @@ dependencies:
|
||||
- pytorch=1.12.1
|
||||
- cudatoolkit=11.6
|
||||
- pip:
|
||||
- accelerate~=0.13
|
||||
- albumentations==0.4.3
|
||||
- diffusers==0.6.0
|
||||
- diffusers[torch]~=0.11
|
||||
- einops==0.3.0
|
||||
- eventlet
|
||||
- flask==2.1.3
|
||||
- flask_cors==3.0.10
|
||||
- flask_socketio==5.3.0
|
||||
- getpass_asterisk
|
||||
- huggingface-hub>=0.11.1
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- imageio==2.9.0
|
||||
- kornia==0.6.0
|
||||
@ -31,13 +33,15 @@ dependencies:
|
||||
- pyreadline3
|
||||
- pytorch-lightning==1.7.7
|
||||
- realesrgan
|
||||
- requests==2.25.1
|
||||
- safetensors~=0.2
|
||||
- send2trash==1.8.0
|
||||
- streamlit==1.12.0
|
||||
- taming-transformers-rom1504
|
||||
- test-tube>=0.7.5
|
||||
- torch-fidelity==0.3.0
|
||||
- torchmetrics==0.7.0
|
||||
- transformers==4.21.3
|
||||
- transformers~=4.25
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
||||
|
@ -1,6 +1,7 @@
|
||||
name: invokeai
|
||||
channels:
|
||||
- pytorch
|
||||
- huggingface
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
@ -19,10 +20,9 @@ dependencies:
|
||||
# sed -E 's/invokeai/invokeai-updated/;20,99s/- ([^=]+)==.+/- \1/' environment-mac.yml > environment-mac-updated.yml
|
||||
# CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac-updated.yml && conda list -n invokeai-updated | awk ' {print " - " $1 "==" $2;} '
|
||||
# ```
|
||||
|
||||
- accelerate
|
||||
- albumentations=1.2
|
||||
- coloredlogs=15.0
|
||||
- diffusers=0.6
|
||||
- einops=0.3
|
||||
- eventlet
|
||||
- grpcio=1.46
|
||||
@ -49,10 +49,14 @@ dependencies:
|
||||
- sympy=1.10
|
||||
- send2trash=1.8
|
||||
- tensorboard=2.10
|
||||
- transformers=4.23
|
||||
- transformers~=4.25
|
||||
- pip:
|
||||
- diffusers[torch]~=0.11
|
||||
- safetensors~=0.2
|
||||
- getpass_asterisk
|
||||
- huggingface-hub
|
||||
- picklescan
|
||||
- requests==2.25.1
|
||||
- taming-transformers-rom1504
|
||||
- test-tube==0.7.5
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
|
@ -12,14 +12,16 @@ dependencies:
|
||||
- pytorch=1.12.1
|
||||
- cudatoolkit=11.6
|
||||
- pip:
|
||||
- accelerate
|
||||
- albumentations==0.4.3
|
||||
- diffusers==0.6.0
|
||||
- diffusers[torch]~=0.11
|
||||
- einops==0.3.0
|
||||
- eventlet
|
||||
- flask==2.1.3
|
||||
- flask_cors==3.0.10
|
||||
- flask_socketio==5.3.0
|
||||
- getpass_asterisk
|
||||
- huggingface-hub>=0.11.1
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- imageio==2.9.0
|
||||
- kornia==0.6.0
|
||||
@ -31,13 +33,16 @@ dependencies:
|
||||
- pyreadline3
|
||||
- pytorch-lightning==1.7.7
|
||||
- realesrgan
|
||||
- requests==2.25.1
|
||||
- safetensors
|
||||
- send2trash==1.8.0
|
||||
- streamlit==1.12.0
|
||||
- taming-transformers-rom1504
|
||||
- test-tube>=0.7.5
|
||||
- torch-fidelity==0.3.0
|
||||
- torchmetrics==0.7.0
|
||||
- transformers==4.21.3
|
||||
- transformers~=4.25
|
||||
- windows-curses
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
||||
|
@ -1,6 +1,9 @@
|
||||
# pip will resolve the version which matches torch
|
||||
accelerate
|
||||
albumentations
|
||||
diffusers==0.10.*
|
||||
datasets
|
||||
diffusers[torch]~=0.11
|
||||
dnspython==2.2.1
|
||||
einops
|
||||
eventlet
|
||||
facexlib
|
||||
@ -14,6 +17,7 @@ huggingface-hub>=0.11.1
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
kornia
|
||||
npyscreen
|
||||
numpy==1.23.*
|
||||
omegaconf
|
||||
opencv-python
|
||||
@ -25,6 +29,7 @@ pyreadline3
|
||||
pytorch-lightning==1.7.7
|
||||
realesrgan
|
||||
requests==2.25.1
|
||||
safetensors
|
||||
scikit-image>=0.19
|
||||
send2trash
|
||||
streamlit
|
||||
@ -32,7 +37,8 @@ taming-transformers-rom1504
|
||||
test-tube>=0.7.5
|
||||
torch-fidelity
|
||||
torchmetrics
|
||||
transformers==4.25.*
|
||||
transformers~=4.25
|
||||
windows-curses; sys_platform == 'win32'
|
||||
https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip#egg=k-diffusion
|
||||
https://github.com/invoke-ai/PyPatchMatch/archive/refs/tags/0.1.5.zip#egg=pypatchmatch
|
||||
https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip#egg=clip
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
frontend/dist/index.html (vendored, 4 changed lines)
@ -7,7 +7,7 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
||||
<link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
|
||||
<script type="module" crossorigin src="./assets/index.ec2d89c6.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index.1b59e83a.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index.0dadf5d0.css">
|
||||
<script type="module">try{import.meta.url;import("_").catch(()=>1);}catch(e){}window.__vite_is_modern_browser=true;</script>
|
||||
<script type="module">!function(){if(window.__vite_is_modern_browser)return;console.warn("vite: loading legacy build because dynamic import or import.meta.url is unsupported, syntax error above should be ignored");var e=document.getElementById("vite-legacy-polyfill"),n=document.createElement("script");n.src=e.src,n.onload=function(){System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))},document.body.appendChild(n)}();</script>
|
||||
@ -18,6 +18,6 @@
|
||||
|
||||
<script nomodule>!function(){var e=document,t=e.createElement("script");if(!("noModule"in t)&&"onbeforeload"in t){var n=!1;e.addEventListener("beforeload",(function(e){if(e.target===t)n=!0;else if(!e.target.hasAttribute("nomodule")||!n)return;e.preventDefault()}),!0),t.type="module",t.src=".",e.head.appendChild(t),t.remove()}}();</script>
|
||||
<script nomodule crossorigin id="vite-legacy-polyfill" src="./assets/polyfills-legacy-dde3a68a.js"></script>
|
||||
<script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-5c5a479d.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script>
|
||||
<script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-474a75fe.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script>
|
||||
</body>
|
||||
</html>
|
||||
|
@ -17,6 +17,8 @@
|
||||
"langPortuguese": "Portuguese",
|
||||
"langFrench": "French",
|
||||
"langPolish": "Polish",
|
||||
"langSimplifiedChinese": "Simplified Chinese",
|
||||
"langSpanish": "Spanish",
|
||||
"text2img": "Text To Image",
|
||||
"img2img": "Image To Image",
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
@ -32,6 +34,7 @@
|
||||
"upload": "Upload",
|
||||
"close": "Close",
|
||||
"load": "Load",
|
||||
"back": "Back",
|
||||
"statusConnected": "Connected",
|
||||
"statusDisconnected": "Disconnected",
|
||||
"statusError": "Error",
|
||||
|
@ -34,6 +34,7 @@
|
||||
"upload": "Upload",
|
||||
"close": "Close",
|
||||
"load": "Load",
|
||||
"back": "Back",
|
||||
"statusConnected": "Connected",
|
||||
"statusDisconnected": "Disconnected",
|
||||
"statusError": "Error",
|
||||
|
@ -1,12 +1,18 @@
{
"modelManager": "Model Manager",
"model": "Model",
"allModels": "All Models",
"checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers",
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",
@ -17,8 +23,12 @@
"configValidationMsg": "Path to the config file of your model.",
"modelLocation": "Model Location",
"modelLocationValidationMsg": "Path to where your model is located.",
"repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.",
"vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width",
"widthValidationMsg": "Default width of your model.",
"height": "Height",
@ -34,6 +44,7 @@
"checkpointFolder": "Checkpoint Folder",
"clearCheckpointFolder": "Clear Checkpoint Folder",
"findModels": "Find Models",
"scanAgain": "Scan Again",
"modelsFound": "Models Found",
"selectFolder": "Select Folder",
"selected": "Selected",
@ -42,9 +53,15 @@
"showExisting": "Show Existing",
"addSelected": "Add Selected",
"modelExists": "Model Exists",
"selectAndAdd": "Select and Add Models Listed Below",
"noModelsFound": "No Models Found",
"delete": "Delete",
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to."
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location",
"formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above."
}
@ -1,12 +1,18 @@
{
"modelManager": "Model Manager",
"model": "Model",
"allModels": "All Models",
"checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers",
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",
@ -17,8 +23,12 @@
"configValidationMsg": "Path to the config file of your model.",
"modelLocation": "Model Location",
"modelLocationValidationMsg": "Path to where your model is located.",
"repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.",
"vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width",
"widthValidationMsg": "Default width of your model.",
"height": "Height",
@ -49,5 +59,9 @@
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to."
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location",
"formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above."
}
@ -1 +1,15 @@
|
||||
{}
|
||||
{
|
||||
"feature": {
|
||||
"prompt": "Questo è il campo del prompt. Il prompt include oggetti di generazione e termini stilistici. Puoi anche aggiungere il peso (importanza del token) nel prompt, ma i comandi e i parametri dell'interfaccia a linea di comando non funzioneranno.",
|
||||
"gallery": "Galleria visualizza le generazioni dalla cartella degli output man mano che vengono create. Le impostazioni sono memorizzate all'interno di file e accessibili dal menu contestuale.",
|
||||
"other": "Queste opzioni abiliteranno modalità di elaborazione alternative per Invoke. 'Piastrella senza cuciture' creerà modelli ripetuti nell'output. 'Ottimizzzazione Alta risoluzione' è la generazione in due passaggi con 'Immagine a Immagine': usa questa impostazione quando vuoi un'immagine più grande e più coerente senza artefatti. Ci vorrà più tempo del solito 'Testo a Immagine'.",
|
||||
"seed": "Il valore del Seme influenza il rumore iniziale da cui è formata l'immagine. Puoi usare i semi già esistenti dalle immagini precedenti. 'Soglia del rumore' viene utilizzato per mitigare gli artefatti a valori CFG elevati (provare l'intervallo 0-10) e Perlin per aggiungere il rumore Perlin durante la generazione: entrambi servono per aggiungere variazioni ai risultati.",
|
||||
"variations": "Prova una variazione con un valore compreso tra 0.1 e 1.0 per modificare il risultato per un dato seme. Variazioni interessanti del seme sono comprese tra 0.1 e 0.3.",
|
||||
"upscale": "Utilizza ESRGAN per ingrandire l'immagine subito dopo la generazione.",
|
||||
"faceCorrection": "Correzione del volto con GFPGAN o Codeformer: l'algoritmo rileva i volti nell'immagine e corregge eventuali difetti. Un valore alto cambierà maggiormente l'immagine, dando luogo a volti più attraenti. Codeformer con una maggiore fedeltà preserva l'immagine originale a scapito di una correzione facciale più forte.",
|
||||
"imageToImage": "Da Immagine a Immagine carica qualsiasi immagine come iniziale, che viene quindi utilizzata per generarne una nuova in base al prompt. Più alto è il valore, più cambierà l'immagine risultante. Sono possibili valori da 0.0 a 1.0, l'intervallo consigliato è 0.25-0.75",
|
||||
"boundingBox": "Il riquadro di selezione è lo stesso delle impostazioni Larghezza e Altezza per da Testo a Immagine o da Immagine a Immagine. Verrà elaborata solo l'area nella casella.",
|
||||
"seamCorrection": "Controlla la gestione delle giunzioni visibili che si verificano tra le immagini generate sulla tela.",
|
||||
"infillAndScaling": "Gestisce i metodi di riempimento (utilizzati su aree mascherate o cancellate dell'area di disegno) e il ridimensionamento (utile per i riquadri di selezione di piccole dimensioni)."
|
||||
}
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
{
|
||||
"feature": {
|
||||
"prompt": "Questo è il campo del prompt. Il prompt include oggetti di generazione e termini stilistici. Puoi anche aggiungere il peso (importanza del token) nel prompt, ma i comandi e i parametri dell'interfaccia a linea di comando non funzioneranno.",
|
||||
"gallery": "Galleria visualizza le generazioni dalla cartella degli output man mano che vengono create. Le impostazioni sono memorizzate all'interno di file e accessibili dal menu contestuale.",
|
||||
"other": "Queste opzioni abiliteranno modalità di elaborazione alternative per Invoke. 'Piastrella senza cuciture' creerà modelli ripetuti nell'output. 'Ottimizzzazione Alta risoluzione' è la generazione in due passaggi con 'Immagine a Immagine': usa questa impostazione quando vuoi un'immagine più grande e più coerente senza artefatti. Ci vorrà più tempo del solito 'Testo a Immagine'.",
|
||||
"seed": "Il valore del Seme influenza il rumore iniziale da cui è formata l'immagine. Puoi usare i semi già esistenti dalle immagini precedenti. 'Soglia del rumore' viene utilizzato per mitigare gli artefatti a valori CFG elevati (provare l'intervallo 0-10) e Perlin per aggiungere il rumore Perlin durante la generazione: entrambi servono per aggiungere variazioni ai risultati.",
|
||||
"variations": "Prova una variazione con un valore compreso tra 0.1 e 1.0 per modificare il risultato per un dato seme. Variazioni interessanti del seme sono comprese tra 0.1 e 0.3.",
|
||||
"upscale": "Utilizza ESRGAN per ingrandire l'immagine subito dopo la generazione.",
|
||||
"faceCorrection": "Correzione del volto con GFPGAN o Codeformer: l'algoritmo rileva i volti nell'immagine e corregge eventuali difetti. Un valore alto cambierà maggiormente l'immagine, dando luogo a volti più attraenti. Codeformer con una maggiore fedeltà preserva l'immagine originale a scapito di una correzione facciale più forte.",
|
||||
"imageToImage": "Da Immagine a Immagine carica qualsiasi immagine come iniziale, che viene quindi utilizzata per generarne una nuova in base al prompt. Più alto è il valore, più cambierà l'immagine risultante. Sono possibili valori da 0.0 a 1.0, l'intervallo consigliato è 0.25-0.75",
|
||||
"boundingBox": "Il riquadro di selezione è lo stesso delle impostazioni Larghezza e Altezza per dat Testo a Immagine o da Immagine a Immagine. Verrà elaborata solo l'area nella casella.",
|
||||
"seamCorrection": "Controlla la gestione delle giunzioni visibili che si verificano tra le immagini generate sulla tela.",
|
||||
"infillAndScaling": "Gestisce i metodi di riempimento (utilizzati su aree mascherate o cancellate dell'area di disegno) e il ridimensionamento (utile per i riquadri di selezione di piccole dimensioni)."
|
||||
}
|
||||
}
|
30
frontend/src/app/invokeai.d.ts
vendored
@ -170,9 +170,23 @@ export declare type Model = {
width?: number;
height?: number;
default?: boolean;
format?: string;
};

export declare type ModelList = Record<string, Model>;
export declare type DiffusersModel = {
status: ModelStatus;
description: string;
repo_id?: string;
path?: string;
vae?: {
repo_id?: string;
path?: string;
};
format?: string;
default?: boolean;
};

export declare type ModelList = Record<string, Model & DiffusersModel>;

export declare type FoundModel = {
name: string;
@ -188,6 +202,20 @@ export declare type InvokeModelConfigProps = {
width: number | undefined;
height: number | undefined;
default: boolean | undefined;
format: string | undefined;
};

export declare type InvokeDiffusersModelConfigProps = {
name: string | undefined;
description: string | undefined;
repo_id: string | undefined;
path: string | undefined;
default: boolean | undefined;
format: string | undefined;
vae: {
repo_id: string | undefined;
path: string | undefined;
};
};

/**
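With ModelList now typed as Record<string, Model & DiffusersModel>, every entry carries both the checkpoint and the diffusers field sets, so the optional format field is the practical way to tell the two apart at runtime. A minimal TypeScript sketch, using only fields visible in the hunk above; the helper itself is illustrative, not part of the diff:

// Sketch: `format` acts as the runtime discriminator between entry kinds.
import type { ModelList } from 'app/invokeai';

function listDiffusersModels(modelList: ModelList): string[] {
  return Object.entries(modelList)
    .filter(([, model]) => model.format === 'diffusers')
    .map(([name]) => name);
}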
@ -32,9 +32,9 @@ export const requestSystemConfig = createAction<undefined>(

export const searchForModels = createAction<string>('socketio/searchForModels');

export const addNewModel = createAction<InvokeAI.InvokeModelConfigProps>(
'socketio/addNewModel'
);
export const addNewModel = createAction<
InvokeAI.InvokeModelConfigProps | InvokeAI.InvokeDiffusersModelConfigProps
>('socketio/addNewModel');

export const deleteModel = createAction<string>('socketio/deleteModel');
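Since addNewModel now accepts InvokeModelConfigProps | InvokeDiffusersModelConfigProps, a caller can pass a diffusers config directly. A minimal sketch, using only the fields declared in the type definitions above; the model name and repo values are hypothetical:

import { addNewModel } from 'app/socketio/actions';
import type { InvokeDiffusersModelConfigProps } from 'app/invokeai';

// Hypothetical example payload; only fields declared in
// InvokeDiffusersModelConfigProps are used.
const diffusersConfig: InvokeDiffusersModelConfigProps = {
  name: 'my-diffusers-model',
  description: 'Example diffusers model',
  repo_id: 'some-org/some-diffusers-repo',
  path: undefined,
  default: false,
  format: 'diffusers',
  vae: { repo_id: undefined, path: undefined },
};

// Dispatched from a component via useAppDispatch (as in the forms later in
// this diff): dispatch(addNewModel(diffusersConfig));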
@ -22,16 +22,16 @@ const layerToDataURL = (
const { x, y, width, height } = layer.getClientRect();
const dataURLBoundingBox = boundingBox
? {
x: Math.round(boundingBox.x + stageCoordinates.x),
y: Math.round(boundingBox.y + stageCoordinates.y),
width: Math.round(boundingBox.width),
height: Math.round(boundingBox.height),
x: boundingBox.x + stageCoordinates.x,
y: boundingBox.y + stageCoordinates.y,
width: boundingBox.width,
height: boundingBox.height,
}
: {
x: Math.round(x),
y: Math.round(y),
width: Math.round(width),
height: Math.round(height),
x: x,
y: y,
width: width,
height: height,
};

const dataURL = layer.toDataURL(dataURLBoundingBox);
@ -42,10 +42,10 @@ const layerToDataURL = (
return {
dataURL,
boundingBox: {
x: Math.round(relativeClientRect.x),
y: Math.round(relativeClientRect.y),
width: Math.round(width),
height: Math.round(height),
x: relativeClientRect.x,
y: relativeClientRect.y,
width: width,
height: height,
},
};
};
@ -57,6 +57,7 @@ export interface OptionsState {
width: number;
shouldUseCanvasBetaLayout: boolean;
shouldShowExistingModelsInSearch: boolean;
addNewModelUIOption: 'ckpt' | 'diffusers' | null;
}

const initialOptionsState: OptionsState = {
@ -105,6 +106,7 @@ const initialOptionsState: OptionsState = {
width: 512,
shouldUseCanvasBetaLayout: false,
shouldShowExistingModelsInSearch: false,
addNewModelUIOption: null,
};

const initialState: OptionsState = initialOptionsState;
@ -412,6 +414,12 @@ export const optionsSlice = createSlice({
) => {
state.shouldShowExistingModelsInSearch = action.payload;
},
setAddNewModelUIOption: (
state,
action: PayloadAction<'ckpt' | 'diffusers' | null>
) => {
state.addNewModelUIOption = action.payload;
},
},
});

@ -469,6 +477,7 @@ export const {
setWidth,
setShouldUseCanvasBetaLayout,
setShouldShowExistingModelsInSearch,
setAddNewModelUIOption,
} = optionsSlice.actions;

export default optionsSlice.reducer;
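The new addNewModelUIOption field drives which add-model view is shown: 'ckpt', 'diffusers', or the picker when null. A minimal sketch of how a component can read and set it, using only imports that appear elsewhere in this diff; the hook wrapper itself is illustrative:

import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setAddNewModelUIOption } from 'features/options/store/optionsSlice';
import type { RootState } from 'app/store';

function useAddNewModelUI() {
  const dispatch = useAppDispatch();
  // Current view: 'ckpt' | 'diffusers' | null (picker)
  const current = useAppSelector(
    (state: RootState) => state.options.addNewModelUIOption
  );

  return {
    current,
    showCheckpointForm: () => dispatch(setAddNewModelUIOption('ckpt')),
    showDiffusersForm: () => dispatch(setAddNewModelUIOption('diffusers')),
    backToPicker: () => dispatch(setAddNewModelUIOption(null)),
  };
}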
@ -3,6 +3,7 @@
.modal {
background-color: var(--background-color-secondary);
color: var(--text-color);
font-family: Inter;
}

.modal-close-btn {
@ -0,0 +1,328 @@
|
||||
import {
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
HStack,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import React from 'react';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAICheckbox from 'common/components/IAICheckbox';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import SearchModels from './SearchModels';
|
||||
|
||||
import { addNewModel } from 'app/socketio/actions';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import type { RootState } from 'app/store';
|
||||
import type { InvokeModelConfigProps } from 'app/invokeai';
|
||||
import { setAddNewModelUIOption } from 'features/options/store/optionsSlice';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { BiArrowBack } from 'react-icons/bi';
|
||||
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
||||
|
||||
export default function AddCheckpointModel() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
function hasWhiteSpace(s: string) {
|
||||
return /\s/.test(s);
|
||||
}
|
||||
|
||||
function baseValidation(value: string) {
|
||||
let error;
|
||||
if (hasWhiteSpace(value)) error = t('modelmanager:cannotUseSpaces');
|
||||
return error;
|
||||
}
|
||||
|
||||
const addModelFormValues: InvokeModelConfigProps = {
|
||||
name: '',
|
||||
description: '',
|
||||
config: 'configs/stable-diffusion/v1-inference.yaml',
|
||||
weights: '',
|
||||
vae: '',
|
||||
width: 512,
|
||||
height: 512,
|
||||
format: 'ckpt',
|
||||
default: false,
|
||||
};
|
||||
|
||||
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
||||
dispatch(addNewModel(values));
|
||||
dispatch(setAddNewModelUIOption(null));
|
||||
};
|
||||
|
||||
const [addManually, setAddmanually] = React.useState<boolean>(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<IAIIconButton
|
||||
aria-label={t('common:back')}
|
||||
tooltip={t('common:back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
width="max-content"
|
||||
position="absolute"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
right={12}
|
||||
top={3}
|
||||
icon={<BiArrowBack />}
|
||||
/>
|
||||
|
||||
<SearchModels />
|
||||
<IAICheckbox
|
||||
label={t('modelmanager:addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
|
||||
{addManually && (
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<VStack rowGap={'0.5rem'}>
|
||||
<Text fontSize={20} fontWeight="bold" alignSelf={'start'}>
|
||||
{t('modelmanager:manual')}
|
||||
</Text>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelmanager:name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelmanager:description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Config */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelmanager:config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Weights */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.weights && touched.weights}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelmanager:modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="weights"
|
||||
name="weights"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.weights && touched.weights ? (
|
||||
<FormErrorMessage>{errors.weights}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE */}
|
||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||
<FormLabel htmlFor="vae" fontSize="sm">
|
||||
{t('modelmanager:vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae"
|
||||
name="vae"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.vae && touched.vae ? (
|
||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
<HStack width={'100%'}>
|
||||
{/* Width */}
|
||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
||||
<FormLabel htmlFor="width" fontSize="sm">
|
||||
{t('modelmanager:width')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field id="width" name="width">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="width"
|
||||
name="width"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
width="90%"
|
||||
value={form.values.width}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.width && touched.width ? (
|
||||
<FormErrorMessage>{errors.width}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:widthValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Height */}
|
||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
||||
<FormLabel htmlFor="height" fontSize="sm">
|
||||
{t('modelmanager:height')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field id="height" name="height">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="height"
|
||||
name="height"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
width="90%"
|
||||
step={64}
|
||||
value={form.values.height}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.height && touched.height ? (
|
||||
<FormErrorMessage>{errors.height}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:heightValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</HStack>
|
||||
|
||||
<IAIButton
|
||||
type="submit"
|
||||
className="modal-close-btn"
|
||||
isLoading={isProcessing}
|
||||
>
|
||||
{t('modelmanager:addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
)}
|
||||
</Formik>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
@ -0,0 +1,310 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { setAddNewModelUIOption } from 'features/options/store/optionsSlice';
|
||||
import { Field, Formik } from 'formik';
|
||||
import React from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BiArrowBack } from 'react-icons/bi';
|
||||
import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
|
||||
import { addNewModel } from 'app/socketio/actions';
|
||||
|
||||
import type { RootState } from 'app/store';
|
||||
import type { ReactElement } from 'react';
|
||||
|
||||
function FormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
backgroundColor="var(--background-color)"
|
||||
padding="1rem 1rem"
|
||||
borderRadius="0.5rem"
|
||||
rowGap="1rem"
|
||||
width="100%"
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
export default function AddDiffusersModel() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
function hasWhiteSpace(s: string) {
|
||||
return /\s/.test(s);
|
||||
}
|
||||
|
||||
function baseValidation(value: string) {
|
||||
let error;
|
||||
if (hasWhiteSpace(value)) error = t('modelmanager:cannotUseSpaces');
|
||||
return error;
|
||||
}
|
||||
|
||||
const addModelFormValues: InvokeDiffusersModelConfigProps = {
|
||||
name: '',
|
||||
description: '',
|
||||
repo_id: '',
|
||||
path: '',
|
||||
format: 'diffusers',
|
||||
default: false,
|
||||
vae: {
|
||||
repo_id: '',
|
||||
path: '',
|
||||
},
|
||||
};
|
||||
|
||||
const addModelFormSubmitHandler = (
|
||||
values: InvokeDiffusersModelConfigProps
|
||||
) => {
|
||||
const diffusersModelToAdd = values;
|
||||
|
||||
if (values.path === '') diffusersModelToAdd['path'] = undefined;
|
||||
if (values.repo_id === '') diffusersModelToAdd['repo_id'] = undefined;
|
||||
if (values.vae.path === '') {
|
||||
if (values.path === undefined) {
|
||||
diffusersModelToAdd['vae']['path'] = undefined;
|
||||
} else {
|
||||
diffusersModelToAdd['vae']['path'] = values.path + '/vae';
|
||||
}
|
||||
}
|
||||
if (values.vae.repo_id === '')
|
||||
diffusersModelToAdd['vae']['repo_id'] = undefined;
|
||||
|
||||
dispatch(addNewModel(diffusersModelToAdd));
|
||||
dispatch(setAddNewModelUIOption(null));
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex>
|
||||
<IAIIconButton
|
||||
aria-label={t('common:back')}
|
||||
tooltip={t('common:back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
width="max-content"
|
||||
position="absolute"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
right={12}
|
||||
top={3}
|
||||
icon={<BiArrowBack />}
|
||||
/>
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<VStack rowGap={'0.5rem'}>
|
||||
<FormItemWrapper>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelmanager:name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="2xl"
|
||||
isRequired
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelmanager:description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="2xl"
|
||||
isRequired
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{t('modelmanager:formMessageDiffusersModelLocation')}
|
||||
</Text>
|
||||
<Text
|
||||
fontSize="sm"
|
||||
fontStyle="italic"
|
||||
color="var(--text-color-secondary)"
|
||||
>
|
||||
{t('modelmanager:formMessageDiffusersModelLocationDesc')}
|
||||
</Text>
|
||||
|
||||
{/* Path */}
|
||||
<FormControl isInvalid={!!errors.path && touched.path}>
|
||||
<FormLabel htmlFor="path" fontSize="sm">
|
||||
{t('modelmanager:modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="path"
|
||||
name="path"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.path && touched.path ? (
|
||||
<FormErrorMessage>{errors.path}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Repo ID */}
|
||||
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
||||
<FormLabel htmlFor="repo_id" fontSize="sm">
|
||||
{t('modelmanager:repo_id')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="repo_id"
|
||||
name="repo_id"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.repo_id && touched.repo_id ? (
|
||||
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:repoIDValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
{/* VAE Path */}
|
||||
<Text fontWeight="bold">
|
||||
{t('modelmanager:formMessageDiffusersVAELocation')}
|
||||
</Text>
|
||||
<Text
|
||||
fontSize="sm"
|
||||
fontStyle="italic"
|
||||
color="var(--text-color-secondary)"
|
||||
>
|
||||
{t('modelmanager:formMessageDiffusersVAELocationDesc')}
|
||||
</Text>
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
||||
>
|
||||
<FormLabel htmlFor="vae.path" fontSize="sm">
|
||||
{t('modelmanager:vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.path"
|
||||
name="vae.path"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.vae?.path && touched.vae?.path ? (
|
||||
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE Repo ID */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
||||
>
|
||||
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
||||
{t('modelmanager:vaeRepoID')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.repo_id"
|
||||
name="vae.repo_id"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
||||
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:vaeRepoIDValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
|
||||
<IAIButton
|
||||
type="submit"
|
||||
className="modal-close-btn"
|
||||
isLoading={isProcessing}
|
||||
>
|
||||
{t('modelmanager:addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -1,10 +1,5 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
HStack,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
@ -13,72 +8,64 @@ import {
|
||||
ModalOverlay,
|
||||
Text,
|
||||
useDisclosure,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import React from 'react';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAICheckbox from 'common/components/IAICheckbox';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import SearchModels from './SearchModels';
|
||||
|
||||
import { addNewModel } from 'app/socketio/actions';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { Field, Formik } from 'formik';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import type { RootState } from 'app/store';
|
||||
import type { InvokeModelConfigProps } from 'app/invokeai';
|
||||
import { setAddNewModelUIOption } from 'features/options/store/optionsSlice';
|
||||
import AddCheckpointModel from './AddCheckpointModel';
|
||||
import AddDiffusersModel from './AddDiffusersModel';
|
||||
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
||||
function AddModelBox({
|
||||
text,
|
||||
onClick,
|
||||
}: {
|
||||
text: string;
|
||||
onClick?: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
width="50%"
|
||||
height="200px"
|
||||
backgroundColor="var(--background-color)"
|
||||
borderRadius="0.5rem"
|
||||
justifyContent="center"
|
||||
alignItems="center"
|
||||
_hover={{
|
||||
cursor: 'pointer',
|
||||
backgroundColor: 'var(--accent-color)',
|
||||
}}
|
||||
onClick={onClick}
|
||||
>
|
||||
<Text fontWeight="bold">{text}</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
export default function AddModel() {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
const addNewModelUIOption = useAppSelector(
|
||||
(state: RootState) => state.options.addNewModelUIOption
|
||||
);
|
||||
|
||||
function hasWhiteSpace(s: string) {
|
||||
return /\\s/g.test(s);
|
||||
}
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
function baseValidation(value: string) {
|
||||
let error;
|
||||
if (hasWhiteSpace(value)) error = t('modelmanager:cannotUseSpaces');
|
||||
return error;
|
||||
}
|
||||
|
||||
const addModelFormValues: InvokeModelConfigProps = {
|
||||
name: '',
|
||||
description: '',
|
||||
config: 'configs/stable-diffusion/v1-inference.yaml',
|
||||
weights: '',
|
||||
vae: '',
|
||||
width: 512,
|
||||
height: 512,
|
||||
default: false,
|
||||
};
|
||||
|
||||
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
||||
dispatch(addNewModel(values));
|
||||
onClose();
|
||||
};
|
||||
const { t } = useTranslation();
|
||||
|
||||
const addModelModalClose = () => {
|
||||
onClose();
|
||||
dispatch(setAddNewModelUIOption(null));
|
||||
};
|
||||
|
||||
const [addManually, setAddmanually] = React.useState<boolean>(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<IAIButton
|
||||
@ -101,266 +88,24 @@ export default function AddModel() {
|
||||
closeOnOverlayClick={false}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent className="modal add-model-modal">
|
||||
<ModalContent className="modal add-model-modal" fontFamily="Inter">
|
||||
<ModalHeader>{t('modelmanager:addNewModel')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalCloseButton marginTop="0.3rem" />
|
||||
<ModalBody className="add-model-modal-body">
|
||||
<SearchModels />
|
||||
<IAICheckbox
|
||||
label={t('modelmanager:addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
|
||||
{addManually && (
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<VStack rowGap={'0.5rem'}>
|
||||
<Text fontSize={20} fontWeight="bold" alignSelf={'start'}>
|
||||
{t('modelmanager:manual')}
|
||||
</Text>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelmanager:name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelmanager:description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>
|
||||
{errors.description}
|
||||
</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Config */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelmanager:config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Weights */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.weights && touched.weights}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelmanager:modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="weights"
|
||||
name="weights"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.weights && touched.weights ? (
|
||||
<FormErrorMessage>
|
||||
{errors.weights}
|
||||
</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE */}
|
||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||
<FormLabel htmlFor="vae" fontSize="sm">
|
||||
{t('modelmanager:vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae"
|
||||
name="vae"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.vae && touched.vae ? (
|
||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
<HStack width={'100%'}>
|
||||
{/* Width */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.width && touched.width}
|
||||
>
|
||||
<FormLabel htmlFor="width" fontSize="sm">
|
||||
{t('modelmanager:width')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field id="width" name="width">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="width"
|
||||
name="width"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
width="90%"
|
||||
value={form.values.width}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(
|
||||
field.name,
|
||||
Number(value)
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.width && touched.width ? (
|
||||
<FormErrorMessage>
|
||||
{errors.width}
|
||||
</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:widthValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Height */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.height && touched.height}
|
||||
>
|
||||
<FormLabel htmlFor="height" fontSize="sm">
|
||||
{t('modelmanager:height')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field id="height" name="height">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="height"
|
||||
name="height"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
width="90%"
|
||||
step={64}
|
||||
value={form.values.height}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(
|
||||
field.name,
|
||||
Number(value)
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.height && touched.height ? (
|
||||
<FormErrorMessage>
|
||||
{errors.height}
|
||||
</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:heightValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</HStack>
|
||||
|
||||
<IAIButton
|
||||
type="submit"
|
||||
className="modal-close-btn"
|
||||
isLoading={isProcessing}
|
||||
>
|
||||
{t('modelmanager:addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
)}
|
||||
</Formik>
|
||||
{addNewModelUIOption == null && (
|
||||
<Flex columnGap="1rem">
|
||||
<AddModelBox
|
||||
text={t('modelmanager:addCheckpointModel')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
||||
/>
|
||||
<AddModelBox
|
||||
text={t('modelmanager:addDiffuserModel')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
||||
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
|
@ -48,7 +48,7 @@ const selector = createSelector(
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;

export default function ModelEdit() {
export default function CheckpointModelEdit() {
const { openModel, model_list } = useAppSelector(selector);
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
@ -68,6 +68,7 @@ export default function ModelEdit() {
width: 512,
height: 512,
default: false,
format: 'ckpt',
});

useEffect(() => {
@ -84,12 +85,19 @@ export default function ModelEdit() {
width: retrievedModel[openModel]?.width,
height: retrievedModel[openModel]?.height,
default: retrievedModel[openModel]?.default,
format: 'ckpt',
});
}
}, [model_list, openModel]);

const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(addNewModel(values));
dispatch(
addNewModel({
...values,
width: Number(values.width),
height: Number(values.height),
})
);
};

return openModel ? (
@ -0,0 +1,270 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { addNewModel } from 'app/socketio/actions';
|
||||
|
||||
import _ from 'lodash';
|
||||
|
||||
import type { RootState } from 'app/store';
|
||||
import type { InvokeDiffusersModelConfigProps } from 'app/invokeai';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
(system) => {
|
||||
const { openModel, model_list } = system;
|
||||
return {
|
||||
model_list,
|
||||
openModel,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: _.isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export default function DiffusersModelEdit() {
|
||||
const { openModel, model_list } = useAppSelector(selector);
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [editModelFormValues, setEditModelFormValues] =
|
||||
useState<InvokeDiffusersModelConfigProps>({
|
||||
name: '',
|
||||
description: '',
|
||||
repo_id: '',
|
||||
path: '',
|
||||
vae: { repo_id: '', path: '' },
|
||||
default: false,
|
||||
format: 'diffusers',
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (openModel) {
|
||||
const retrievedModel = _.pickBy(model_list, (val, key) => {
|
||||
return _.isEqual(key, openModel);
|
||||
});
|
||||
|
||||
setEditModelFormValues({
|
||||
name: openModel,
|
||||
description: retrievedModel[openModel]?.description,
|
||||
path: retrievedModel[openModel]?.path,
|
||||
repo_id: retrievedModel[openModel]?.repo_id,
|
||||
vae: {
|
||||
repo_id: retrievedModel[openModel]?.vae?.repo_id
|
||||
? retrievedModel[openModel]?.vae?.repo_id
|
||||
: '',
|
||||
path: retrievedModel[openModel]?.vae?.path
|
||||
? retrievedModel[openModel]?.vae?.path
|
||||
: '',
|
||||
},
|
||||
default: retrievedModel[openModel]?.default,
|
||||
format: 'diffusers',
|
||||
});
|
||||
}
|
||||
}, [model_list, openModel]);
|
||||
|
||||
const editModelFormSubmitHandler = (
|
||||
values: InvokeDiffusersModelConfigProps
|
||||
) => {
|
||||
dispatch(addNewModel(values));
|
||||
};
|
||||
|
||||
return openModel ? (
|
||||
<Flex flexDirection="column" rowGap="1rem" width="100%">
|
||||
<Flex alignItems="center">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{openModel}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
maxHeight={window.innerHeight - 270}
|
||||
overflowY="scroll"
|
||||
paddingRight="2rem"
|
||||
>
|
||||
<Formik
|
||||
enableReinitialize={true}
|
||||
initialValues={editModelFormValues}
|
||||
onSubmit={editModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<VStack rowGap={'0.5rem'} alignItems="start">
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelmanager:description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="lg"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Path */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.path && touched.path}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="path" fontSize="sm">
|
||||
{t('modelmanager:modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="path"
|
||||
name="path"
|
||||
type="text"
|
||||
width="lg"
|
||||
/>
|
||||
{!!errors.path && touched.path ? (
|
||||
<FormErrorMessage>{errors.path}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Repo ID */}
|
||||
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
||||
<FormLabel htmlFor="repo_id" fontSize="sm">
|
||||
{t('modelmanager:repo_id')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="repo_id"
|
||||
name="repo_id"
|
||||
type="text"
|
||||
width="lg"
|
||||
/>
|
||||
{!!errors.repo_id && touched.repo_id ? (
|
||||
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:repoIDValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE Path */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
||||
>
|
||||
<FormLabel htmlFor="vae.path" fontSize="sm">
|
||||
{t('modelmanager:vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.path"
|
||||
name="vae.path"
|
||||
type="text"
|
||||
width="lg"
|
||||
/>
|
||||
{!!errors.vae?.path && touched.vae?.path ? (
|
||||
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE Repo ID */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
||||
>
|
||||
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
||||
{t('modelmanager:vaeRepoID')}
|
||||
</FormLabel>
|
||||
<VStack alignItems={'start'}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.repo_id"
|
||||
name="vae.repo_id"
|
||||
type="text"
|
||||
width="lg"
|
||||
/>
|
||||
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
||||
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelmanager:vaeRepoIDValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
<IAIButton
|
||||
type="submit"
|
||||
className="modal-close-btn"
|
||||
isLoading={isProcessing}
|
||||
>
|
||||
{t('modelmanager:updateModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
</Flex>
|
||||
) : (
|
||||
<Flex
|
||||
width="100%"
|
||||
height="250px"
|
||||
justifyContent="center"
|
||||
alignItems="center"
|
||||
backgroundColor="var(--background-color)"
|
||||
borderRadius="0.5rem"
|
||||
>
|
||||
<Text fontWeight="bold" color="var(--subtext-color-bright)">
|
||||
Pick A Model To Edit
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
import { useState } from 'react';
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import React, { useState, useTransition, useMemo } from 'react';
|
||||
import { Box, Flex, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
|
||||
@ -14,6 +14,7 @@ import _ from 'lodash';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import type { RootState } from 'app/store';
|
||||
import type { SystemState } from 'features/system/store/systemSlice';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
const modelListSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
@ -21,33 +22,64 @@ const modelListSelector = createSelector(
|
||||
const models = _.map(system.model_list, (model, key) => {
|
||||
return { name: key, ...model };
|
||||
});
|
||||
|
||||
const activeModel = models.find((model) => model.status === 'active');
|
||||
|
||||
return {
|
||||
models,
|
||||
activeModel: activeModel,
|
||||
};
|
||||
return models;
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: _.isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
function ModelFilterButton({
|
||||
label,
|
||||
isActive,
|
||||
onClick,
|
||||
}: {
|
||||
label: string;
|
||||
isActive: boolean;
|
||||
onClick: () => void;
|
||||
}) {
|
||||
return (
|
||||
<IAIButton
|
||||
onClick={onClick}
|
||||
isActive={isActive}
|
||||
_active={{
|
||||
backgroundColor: 'var(--accent-color)',
|
||||
_hover: { backgroundColor: 'var(--accent-color)' },
|
||||
}}
|
||||
size="sm"
|
||||
>
|
||||
{label}
|
||||
</IAIButton>
|
||||
);
|
||||
}
|
||||
|
||||
const ModelList = () => {
|
||||
const { models } = useAppSelector(modelListSelector);
|
||||
const models = useAppSelector(modelListSelector);
|
||||
|
||||
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'ckpt' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();

const { t } = useTranslation();

const handleSearchFilter = _.debounce((e: ChangeEvent<HTMLInputElement>) => {
setSearchText(e.target.value);
}, 400);
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
startTransition(() => {
setSearchText(e.target.value);
});
};

const renderModelListItems = () => {
const modelListItemsToRender: ReactNode[] = [];
const renderModelListItems = useMemo(() => {
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];

models.forEach((model, i) => {
if (model.name.startsWith(searchText)) {
if (model.name.toLowerCase().includes(searchText.toLowerCase())) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
@ -56,21 +88,93 @@ const ModelList = () => {
description={model.description}
/>
);
if (model.format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
name={model.name}
status={model.status}
description={model.description}
/>
);
}
}
if (model.format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
name={model.name}
status={model.status}
description={model.description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
name={model.name}
status={model.status}
description={model.description}
/>
);
}
modelListItemsToRender.push(
<ModelListItem
key={i}
name={model.name}
status={model.status}
description={model.description}
/>
);
});

return searchText !== ''
? filteredModelListItemsToRender
: modelListItemsToRender;
};
return searchText !== '' ? (
isSelectedFilter === 'all' ? (
<Box marginTop="1rem">{filteredModelListItemsToRender}</Box>
) : (
<Box marginTop="1rem">{localFilteredModelListItemsToRender}</Box>
)
) : (
<Flex flexDirection="column" rowGap="1.5rem">
{isSelectedFilter === 'all' && (
<>
<Box>
<Text
fontWeight="bold"
backgroundColor="var(--background-color)"
padding="0.5rem 1rem"
borderRadius="0.5rem"
margin="1rem 0"
width="max-content"
fontSize="14"
>
{t('modelmanager:checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
<Box>
<Text
fontWeight="bold"
backgroundColor="var(--background-color)"
padding="0.5rem 1rem"
borderRadius="0.5rem"
marginBottom="0.5rem"
width="max-content"
fontSize="14"
>
{t('modelmanager:diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
</>
)}

{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop="1rem">
{ckptModelListItemsToRender}
</Flex>
)}

{isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop="1rem">
{diffusersModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [models, searchText, t, isSelectedFilter]);

return (
<Flex flexDirection={'column'} rowGap="2rem" width="50%" minWidth="50%">
@ -93,7 +197,24 @@ const ModelList = () => {
overflow={'scroll'}
paddingRight="1rem"
>
{renderModelListItems()}
<Flex columnGap="0.5rem">
<ModelFilterButton
label={t('modelmanager:allModels')}
onClick={() => setIsSelectedFilter('all')}
isActive={isSelectedFilter === 'all'}
/>
<ModelFilterButton
label={t('modelmanager:checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
<ModelFilterButton
label={t('modelmanager:diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
</Flex>
{renderModelListItems}
</Flex>
</Flex>
);
@ -8,13 +8,17 @@ import {
useDisclosure,
} from '@chakra-ui/react';
import React, { cloneElement } from 'react';
import { useTranslation } from 'react-i18next';

import ModelEdit from './ModelEdit';
import ModelList from './ModelList';
import { useTranslation } from 'react-i18next';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';

import type { ReactElement } from 'react';

import ModelList from './ModelList';
import DiffusersModelEdit from './DiffusersModelEdit';
import CheckpointModelEdit from './CheckpointModelEdit';

type ModelManagerModalProps = {
children: ReactElement;
};
@ -28,6 +32,14 @@ export default function ModelManagerModal({
onClose: onModelManagerModalClose,
} = useDisclosure();

const model_list = useAppSelector(
(state: RootState) => state.system.model_list
);

const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);

const { t } = useTranslation();

return (
@ -41,16 +53,22 @@ export default function ModelManagerModal({
size="6xl"
>
<ModalOverlay />
<ModalContent className=" modal">
<ModalContent className="modal" fontFamily="Inter">
<ModalCloseButton className="modal-close-btn" />
<ModalHeader>{t('modelmanager:modelManager')}</ModalHeader>
<ModalHeader fontWeight="bold">
{t('modelmanager:modelManager')}
</ModalHeader>
<Flex
padding={'0 1.5rem 1.5rem 1.5rem'}
width="100%"
columnGap={'2rem'}
>
<ModelList />
<ModelEdit />
{openModel && model_list[openModel]['format'] === 'diffusers' ? (
<DiffusersModelEdit />
) : (
<CheckpointModelEdit />
)}
</Flex>
</ModalContent>
</Modal>
@ -178,6 +178,7 @@ export default function SearchModels() {
width: 512,
height: 512,
default: false,
format: 'ckpt',
};
dispatch(addNewModel(modelFormat));
});
@ -13,10 +13,6 @@
width: 32px;
height: 32px;
}

h1 {
font-size: 1.4rem;
}
}

.site-header-right-side {
@ -1,4 +1,4 @@
import { Link } from '@chakra-ui/react';
import { Flex, Link, Text } from '@chakra-ui/react';

import { FaGithub, FaDiscord, FaBug, FaKeyboard, FaCube } from 'react-icons/fa';

@ -17,20 +17,34 @@ import LanguagePicker from './LanguagePicker';

import { useTranslation } from 'react-i18next';
import { MdSettings } from 'react-icons/md';
import { useAppSelector } from 'app/storeHooks';
import type { RootState } from 'app/store';

/**
* Header, includes color mode toggle, settings button, status message.
*/
const SiteHeader = () => {
const { t } = useTranslation();
const appVersion = useAppSelector(
(state: RootState) => state.system.app_version
);

return (
<div className="site-header">
<div className="site-header-left-side">
<img src={InvokeAILogo} alt="invoke-ai-logo" />
<h1>
invoke <strong>ai</strong>
</h1>
<Flex alignItems="center" columnGap="0.6rem">
<Text fontSize="1.4rem">
invoke <strong>ai</strong>
</Text>
<Text
fontWeight="bold"
color="var(--text-color-secondary)"
marginTop="0.2rem"
>
{appVersion}
</Text>
</Flex>
</div>

<div className="site-header-right-side">
20 installer/create_installer.sh Executable file → Normal file
@ -2,23 +2,27 @@

cd "$(dirname "$0")"

VERSION=$(grep ^VERSION ../setup.py | awk '{ print $3 }' | sed "s/'//g" )
VERSION=$(cd ..; python -c "from ldm.invoke import __version__ as version; print(version)")
PATCH=""
VERSION="v${VERSION}${PATCH}"

echo Building installer for version $VERSION
echo "Be certain that you're in the 'installer' directory before continuing."
read -p "Press any key to continue, or CTRL-C to exit..."

git commit -a
read -e -p "Commit and tag this repo with ${VERSION} and 'latest'? [n]: " input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
git commit -a

if ! git tag $VERSION ; then
echo "Existing/invalid tag"
exit -1
if ! git tag $VERSION ; then
echo "Existing/invalid tag"
exit -1
fi
git push origin :refs/tags/latest
git tag -fa latest
fi

git push origin :refs/tags/latest
git tag -fa latest

echo Building installer zip fles for InvokeAI $VERSION

# get rid of any old ones
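For context on the version-lookup change above: instead of grepping setup.py, the installer now asks the package itself for its version. A minimal sketch of the same idea in Python, assuming ldm/invoke/__init__.py re-exports __version__ as shown later in this diff:

    # sketch: derive the installer tag from the package's own version string
    from ldm.invoke import __version__ as version

    PATCH = ""                      # optional patch suffix, e.g. "rc1" (hypothetical)
    tag = f"v{version}{PATCH}"      # mirrors VERSION="v${VERSION}${PATCH}" in the script
    print(tag)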
@ -33,7 +33,7 @@ echo 1. Install python 3.9 or higher.
echo 2. Double-click on the file WinLongPathsEnabled.reg in order to
echo enable long path support on your system.
echo 3. Install the Visual C++ core libraries.
echo Pleaase download and install the libraries from:
echo Please download and install the libraries from:
echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170
echo.
echo See %INSTRUCTIONS% for more details.
@ -78,6 +78,7 @@ set "selection="
set /p selection=Select the path to install InvokeAI's directory into [%UserProfile%]:
if not defined selection set selection=%UserProfile%
set selection=%selection:"=%
call :Trim selection !selection!
set dest="%selection%\invokeai"
if exist %dest% (
set response=y
@ -96,6 +97,8 @@ goto :pick_rootdir

set rootdir=%rootdir:"=%



@rem ---------------------- Initialize the runtime directory ---------------------
echo.
echo *** Creating Runtime Directory %rootdir% ***
@ -150,6 +153,7 @@ if %errorlevel% neq 0 (
goto :err_exit
)
POPD

copy .\templates\invoke.bat.in "%rootdir%\invoke.bat"
copy .\templates\update.bat.in "%rootdir%\update.bat"

@ -217,3 +221,10 @@ pause
exit /b

pause

:Trim
SetLocal EnableDelayedExpansion
set Params=%*
for /f "tokens=1*" %%a in ("!Params!") do EndLocal & set %1=%%b
exit /b

@ -9,8 +9,11 @@ set INVOKEAI_ROOT=.
echo Do you want to generate images using the
echo 1. command-line
echo 2. browser-based UI
echo 3. open the developer console
set /P restore="Please enter 1, 2 or 3: "
echo 3. run textual inversion training
echo 4. open the developer console
echo 5. re-run the configure script to download new models
set /P restore="Please enter 1, 2, 3, 4 or 5: [5] "
if not defined restore set restore=2
IF /I "%restore%" == "1" (
echo Starting the InvokeAI command-line..
python .venv\Scripts\invoke.py %*
@ -18,6 +21,9 @@ IF /I "%restore%" == "1" (
echo Starting the InvokeAI browser-based UI..
python .venv\Scripts\invoke.py --web %*
) ELSE IF /I "%restore%" == "3" (
echo Starting textual inversion training..
python .venv\Scripts\textual_inversion_fe.py --web %*
) ELSE IF /I "%restore%" == "4" (
echo Developer Console
echo Python command is:
where python
@ -29,6 +35,9 @@ IF /I "%restore%" == "1" (
echo *************************
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) ELSE IF /I "%restore%" == "5" (
echo Running configure_invokeai.py...
python .venv\Scripts\configure_invokeai.py --web %*
) ELSE (
echo Invalid selection
pause
@ -19,12 +19,17 @@ if [ "$0" != "bash" ]; then
echo "Do you want to generate images using the"
echo "1. command-line"
echo "2. browser-based UI"
echo "3. open the developer console"
read -p "Please enter 1, 2, or 3: " yn
case $yn in
echo "3. run textual inversion training"
echo "4. open the developer console"
echo "5. re-run the configure script to download new models"
read -p "Please enter 1, 2, 3, 4 or 5: [1] " yn
choice=${yn:='2'}
case $choice in
1 ) printf "\nStarting the InvokeAI command-line..\n"; .venv/bin/python .venv/bin/invoke.py $*;;
2 ) printf "\nStarting the InvokeAI browser-based UI..\n"; .venv/bin/python .venv/bin/invoke.py --web $*;;
3 ) printf "\nDeveloper Console:\n"; file_name=$(basename "${BASH_SOURCE[0]}"); bash --init-file "$file_name";;
3 ) printf "\nStarting Textual Inversion:\n"; .venv/bin/python .venv/bin/textual_inversion_fe.py $*;;
4 ) printf "\nDeveloper Console:\n"; file_name=$(basename "${BASH_SOURCE[0]}"); bash --init-file "$file_name";;
5 ) printf "\nRunning configure_invokeai.py:\n"; .venv/bin/python .venv/bin/configure_invokeai.py $*;;
* ) echo "Invalid selection"; exit;;
esac
else # in developer console
@ -23,6 +23,7 @@ if "%arg%" neq "" (

set INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/!INVOKE_AI_VERSION!.zip"
set INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/!INVOKE_AI_VERSION!/environments-and-requirements/requirements-base.txt
set INVOKE_AI_MODELS=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/configs/INITIAL_MODELS.yaml

call curl -I "%INVOKE_AI_DEP%" -fs >.tmp.out
if %errorlevel% neq 0 (
@ -38,6 +39,8 @@ echo If you do not want to do this, press control-C now!
pause

call curl -L "%INVOKE_AI_DEP%" > environments-and-requirements/requirements-base.txt
call curl -L "%INVOKE_AI_MODELS%" > configs/INITIAL_MODELS.yaml


call .venv\Scripts\activate.bat
call .venv\Scripts\python -mpip install -r requirements.txt

@ -18,6 +18,7 @@ INVOKE_AI_VERSION=${1:-latest}

INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/$INVOKE_AI_VERSION.zip"
INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/environments-and-requirements/requirements-base.txt
INVOKE_AI_MODELS=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/configs/INITIAL_MODELS.yaml

# ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
@ -44,6 +45,7 @@ echo If you do not want to do this, press control-C now!
read -p "Press any key to continue, or CTRL-C to exit..."

curl -L "$INVOKE_AI_DEP" > environments-and-requirements/requirements-base.txt
curl -L "$INVOKE_AI_MODELS" > configs/INITIAL_MODELS.yaml

. .venv/bin/activate

248 ldm/generate.py
@ -1,48 +1,44 @@
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
import pyparsing
|
||||
# Derived from source code carrying the following copyrights
|
||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import gc
|
||||
import importlib
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import transformers
|
||||
import io
|
||||
import gc
|
||||
import hashlib
|
||||
|
||||
import cv2
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import skimage
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import ldm.invoke.conditioning
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
import torch
|
||||
import transformers
|
||||
from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch_lightning import seed_everything, logging
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.invoke.pngwriter import PngWriter
|
||||
import ldm.invoke.conditioning
|
||||
from ldm.invoke.args import metadata_from_png
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.invoke.conditioning import get_uc_and_c_and_ec
|
||||
from ldm.invoke.model_cache import ModelCache
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
from ldm.invoke.concepts_lib import Concepts
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.generator.inpaint import infill_methods
|
||||
from ldm.invoke.globals import global_cache_dir
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
from ldm.invoke.pngwriter import PngWriter
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
|
||||
def fix_func(orig):
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
@ -160,12 +156,12 @@ class Generate:
|
||||
mconfig = OmegaConf.load(conf)
|
||||
self.height = None
|
||||
self.width = None
|
||||
self.model_cache = None
|
||||
self.model_manager = None
|
||||
self.iterations = 1
|
||||
self.steps = 50
|
||||
self.cfg_scale = 7.5
|
||||
self.sampler_name = sampler_name
|
||||
self.ddim_eta = 0.0 # same seed always produces same image
|
||||
self.ddim_eta = ddim_eta # same seed always produces same image
|
||||
self.precision = precision
|
||||
self.strength = 0.75
|
||||
self.seamless = False
|
||||
@ -177,7 +173,6 @@ class Generate:
|
||||
self.sampler = None
|
||||
self.device = None
|
||||
self.session_peakmem = None
|
||||
self.generators = {}
|
||||
self.base_generator = None
|
||||
self.seed = None
|
||||
self.outdir = outdir
|
||||
@ -208,8 +203,14 @@ class Generate:
|
||||
self.precision = choose_precision(self.device)
|
||||
|
||||
# model caching system for fast switching
|
||||
self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
|
||||
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
|
||||
self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
|
||||
# don't accept invalid models
|
||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||
model = model or fallback
|
||||
if not self.model_manager.valid_model(model):
|
||||
print(f'** "{model}" is not a known model name; falling back to {fallback}.')
|
||||
model = None
|
||||
self.model_name = model or fallback
|
||||
|
||||
# for VRAM usage statistics
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||
@ -225,7 +226,7 @@ class Generate:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_model_path = os.path.join(Globals.root,'models',safety_model_id)
|
||||
safety_model_path = global_cache_dir("hub")
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
@ -404,7 +405,11 @@ class Generate:
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
|
||||
configure_model_padding(model, seamless, seamless_axes)
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
configure_model_padding(model.unet, seamless, seamless_axes)
|
||||
configure_model_padding(model.vae, seamless, seamless_axes)
|
||||
else:
|
||||
configure_model_padding(model, seamless, seamless_axes)
|
||||
|
||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||
assert threshold >= 0.0, '--threshold must be >=0.0'
|
||||
@ -439,7 +444,7 @@ class Generate:
|
||||
self._set_sampler()
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt = self.concept_lib().replace_concepts_with_triggers(prompt, lambda concepts: self.load_concepts(concepts))
|
||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
||||
|
||||
# bit of a hack to change the cached sampler's karras threshold to
|
||||
# whatever the user asked for
|
||||
@ -546,7 +551,7 @@ class Generate:
|
||||
print('**Interrupted** Partial results will be returned.')
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
except RuntimeError as e:
|
||||
except RuntimeError:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
|
||||
@ -558,7 +563,7 @@ class Generate:
|
||||
)
|
||||
if self._has_cuda():
|
||||
print(
|
||||
f'>> Max VRAM used for this generation:',
|
||||
'>> Max VRAM used for this generation:',
|
||||
'%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'Current VRAM utilization:',
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
@ -568,7 +573,7 @@ class Generate:
|
||||
self.session_peakmem, torch.cuda.max_memory_allocated()
|
||||
)
|
||||
print(
|
||||
f'>> Max VRAM used since script start: ',
|
||||
'>> Max VRAM used since script start: ',
|
||||
'%4.2fG' % (self.session_peakmem / 1e9),
|
||||
)
|
||||
return results
|
||||
@ -613,9 +618,9 @@ class Generate:
|
||||
# used by multiple postfixers
|
||||
# todo: cross-attention control
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
prompt, model=self.model,
|
||||
skip_normalize_legacy_blend=opt.skip_normalize,
|
||||
log_tokens =ldm.invoke.conditioning.log_tokenization
|
||||
log_tokens=ldm.invoke.conditioning.log_tokenization
|
||||
)
|
||||
|
||||
if tool in ('gfpgan','codeformer','upscale'):
|
||||
@ -644,7 +649,7 @@ class Generate:
|
||||
try:
|
||||
extend_instructions[direction]=int(pixels)
|
||||
except ValueError:
|
||||
print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
|
||||
print('** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
|
||||
|
||||
opt.seed = seed
|
||||
opt.prompt = prompt
|
||||
@ -692,7 +697,7 @@ class Generate:
|
||||
)
|
||||
|
||||
elif tool is None:
|
||||
print(f'* please provide at least one postprocessing option, such as -G or -U')
|
||||
print('* please provide at least one postprocessing option, such as -G or -U')
|
||||
return None
|
||||
else:
|
||||
print(f'* postprocessing tool {tool} is not yet supported')
|
||||
@ -769,75 +774,62 @@ class Generate:
|
||||
|
||||
return init_image,init_mask
|
||||
|
||||
# lots o' repeated code here! Turn into a make_func()
|
||||
def _make_base(self):
|
||||
if not self.generators.get('base'):
|
||||
from ldm.invoke.generator import Generator
|
||||
self.generators['base'] = Generator(self.model, self.precision)
|
||||
return self.generators['base']
|
||||
|
||||
def _make_img2img(self):
|
||||
if not self.generators.get('img2img'):
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
self.generators['img2img'] = Img2Img(self.model, self.precision)
|
||||
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['img2img']
|
||||
|
||||
def _make_embiggen(self):
|
||||
if not self.generators.get('embiggen'):
|
||||
from ldm.invoke.generator.embiggen import Embiggen
|
||||
self.generators['embiggen'] = Embiggen(self.model, self.precision)
|
||||
return self.generators['embiggen']
|
||||
return self._load_generator('','Generator')
|
||||
|
||||
def _make_txt2img(self):
|
||||
if not self.generators.get('txt2img'):
|
||||
from ldm.invoke.generator.txt2img import Txt2Img
|
||||
self.generators['txt2img'] = Txt2Img(self.model, self.precision)
|
||||
self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['txt2img']
|
||||
return self._load_generator('.txt2img','Txt2Img')
|
||||
|
||||
def _make_img2img(self):
|
||||
return self._load_generator('.img2img','Img2Img')
|
||||
|
||||
def _make_embiggen(self):
|
||||
return self._load_generator('.embiggen','Embiggen')
|
||||
|
||||
def _make_txt2img2img(self):
|
||||
if not self.generators.get('txt2img2'):
|
||||
from ldm.invoke.generator.txt2img2img import Txt2Img2Img
|
||||
self.generators['txt2img2'] = Txt2Img2Img(self.model, self.precision)
|
||||
self.generators['txt2img2'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['txt2img2']
|
||||
return self._load_generator('.txt2img2img','Txt2Img2Img')
|
||||
|
||||
def _make_inpaint(self):
|
||||
if not self.generators.get('inpaint'):
|
||||
from ldm.invoke.generator.inpaint import Inpaint
|
||||
self.generators['inpaint'] = Inpaint(self.model, self.precision)
|
||||
self.generators['inpaint'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['inpaint']
|
||||
return self._load_generator('.inpaint','Inpaint')
|
||||
|
||||
# "omnibus" supports the runwayML custom inpainting model, which does
|
||||
# txt2img, img2img and inpainting using slight variations on the same code
|
||||
def _make_omnibus(self):
|
||||
if not self.generators.get('omnibus'):
|
||||
from ldm.invoke.generator.omnibus import Omnibus
|
||||
self.generators['omnibus'] = Omnibus(self.model, self.precision)
|
||||
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['omnibus']
|
||||
return self._load_generator('.omnibus','Omnibus')
|
||||
|
||||
def _load_generator(self, module, class_name):
|
||||
if self.is_legacy_model(self.model_name):
|
||||
mn = f'ldm.invoke.ckpt_generator{module}'
|
||||
cn = f'Ckpt{class_name}'
|
||||
else:
|
||||
mn = f'ldm.invoke.generator{module}'
|
||||
cn = class_name
|
||||
module = importlib.import_module(mn)
|
||||
constructor = getattr(module,cn)
|
||||
return constructor(self.model, self.precision)
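The refactor above collapses the per-generator factory methods into a single _load_generator() helper that builds a module path and class name, then imports them on demand. A minimal sketch of that pattern, under the assumption that each generator class takes (model, precision) as shown in the diff:

    import importlib

    def load_generator(model, precision, module_suffix: str, class_name: str, legacy: bool):
        # legacy .ckpt models live under ldm.invoke.ckpt_generator with Ckpt-prefixed classes
        if legacy:
            module_name = f'ldm.invoke.ckpt_generator{module_suffix}'
            class_name = f'Ckpt{class_name}'
        else:
            module_name = f'ldm.invoke.generator{module_suffix}'
        module = importlib.import_module(module_name)
        constructor = getattr(module, class_name)
        return constructor(model, precision)

    # e.g. load_generator(model, 'float16', '.txt2img', 'Txt2Img', legacy=False)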
|
||||
|
||||
def load_model(self):
|
||||
'''
|
||||
preload model identified in self.model_name
|
||||
'''
|
||||
self.set_model(self.model_name)
|
||||
return self.set_model(self.model_name)
|
||||
|
||||
def set_model(self,model_name):
|
||||
"""
|
||||
Given the name of a model defined in models.yaml, will load and initialize it
|
||||
and return the model object. Previously-used models will be cached.
|
||||
|
||||
If the passed model_name is invalid, raises a KeyError.
|
||||
If the model fails to load for some reason, will attempt to load the previously-
|
||||
loaded model (if any). If that fallback fails, will raise an AssertionError
|
||||
"""
|
||||
if self.model_name == model_name and self.model is not None:
|
||||
return self.model
|
||||
|
||||
previous_model_name = self.model_name
|
||||
|
||||
# the model cache does the loading and offloading
|
||||
cache = self.model_cache
|
||||
cache = self.model_manager
|
||||
if not cache.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return self.model
|
||||
raise KeyError('** "{model_name}" is not a known model name. Cannot change.')
|
||||
|
||||
cache.print_vram_usage()
|
||||
|
||||
@ -847,11 +839,17 @@ class Generate:
|
||||
self.sampler = None
|
||||
self.generators = {}
|
||||
gc.collect()
|
||||
|
||||
model_data = cache.get_model(model_name)
|
||||
if model_data is None: # restore previous
|
||||
model_data = cache.get_model(self.model_name)
|
||||
model_name = self.model_name # addresses Issue #1547
|
||||
try:
|
||||
model_data = cache.get_model(model_name)
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
if previous_model_name is None:
|
||||
raise e
|
||||
print(f'** trying to reload previous model')
|
||||
model_data = cache.get_model(previous_model_name) # load previous
|
||||
if model_data is None:
|
||||
raise e
|
||||
model_name = previous_model_name
|
||||
|
||||
self.model = model_data['model']
|
||||
self.width = model_data['width']
|
||||
@ -863,19 +861,23 @@ class Generate:
|
||||
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
if self.embedding_path is not None:
|
||||
self.model.embedding_manager.load(
|
||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||
)
|
||||
for root, _, files in os.walk(self.embedding_path):
|
||||
for name in files:
|
||||
ti_path = os.path.join(root, name)
|
||||
self.model.textual_inversion_manager.load_textual_inversion(ti_path,
|
||||
defer_injecting_tokens=True)
|
||||
print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}')
|
||||
|
||||
self._set_sampler()
|
||||
self.model_name = model_name
|
||||
self._set_sampler() # requires self.model_name to be set first
|
||||
return self.model
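The set_model() change above wraps model loading in a try/except so that a failed switch falls back to the previously loaded model instead of leaving the generator in a broken state. A simplified sketch of that control flow, using a hypothetical cache object with the get_model() interface shown above:

    def switch_model(cache, current_name, new_name):
        # try to load the requested model; on failure, fall back to the current one
        try:
            return new_name, cache.get_model(new_name)
        except Exception as e:
            print(f'** model {new_name} could not be loaded: {e}')
            if current_name is None:
                raise                                       # nothing to fall back to
            print('** trying to reload previous model')
            model_data = cache.get_model(current_name)      # may itself fail and re-raise
            if model_data is None:
                raise
            return current_name, model_data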
|
||||
|
||||
def load_concepts(self,concepts:list[str]):
|
||||
self.model.embedding_manager.load_concepts(concepts, self.precision=='float32' or self.precision=='autocast')
|
||||
def load_huggingface_concepts(self, concepts:list[str]):
|
||||
self.model.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
|
||||
def concept_lib(self)->Concepts:
|
||||
return self.model.embedding_manager.concepts_library
|
||||
@property
|
||||
def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
|
||||
return self.model.textual_inversion_manager.hf_concepts_library
|
||||
|
||||
def correct_colors(self,
|
||||
image_list,
|
||||
@ -970,9 +972,18 @@ class Generate:
|
||||
def sample_to_lowres_estimated_image(self, samples):
|
||||
return self._make_base().sample_to_lowres_estimated_image(samples)
|
||||
|
||||
def is_legacy_model(self,model_name)->bool:
|
||||
return self.model_manager.is_legacy(model_name)
|
||||
|
||||
def _set_sampler(self):
|
||||
if isinstance(self.model, DiffusionPipeline):
|
||||
return self._set_scheduler()
|
||||
else:
|
||||
return self._set_sampler_legacy()
|
||||
|
||||
# very repetitive code - can this be simplified? The KSampler names are
|
||||
# consistent, at least
|
||||
def _set_sampler(self):
|
||||
def _set_sampler_legacy(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
if self.sampler_name == 'plms':
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
@ -1000,6 +1011,41 @@ class Generate:
|
||||
|
||||
print(msg)
|
||||
|
||||
def _set_scheduler(self):
|
||||
default = self.model.scheduler
|
||||
|
||||
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
|
||||
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
|
||||
# provide an alias for compatibility.
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
if self.sampler_name in scheduler_map:
|
||||
sampler_class = scheduler_map[self.sampler_name]
|
||||
msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})'
|
||||
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
||||
else:
|
||||
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
|
||||
f'Defaulting to {default}')
|
||||
self.sampler = default
|
||||
|
||||
print(msg)
|
||||
|
||||
if not hasattr(self.sampler, 'uses_inpainting_model'):
|
||||
# FIXME: terrible kludge!
|
||||
self.sampler.uses_inpainting_model = lambda: False
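The _set_scheduler() hunk above maps InvokeAI's legacy sampler names onto diffusers scheduler classes and swaps them in via from_config(). A condensed sketch of that lookup, assuming a loaded DiffusionPipeline in `pipeline`; the map here lists only a few of the entries shown above:

    import diffusers

    SCHEDULER_MAP = {
        'ddim': diffusers.DDIMScheduler,
        'k_euler': diffusers.EulerDiscreteScheduler,
        'k_euler_a': diffusers.EulerAncestralDiscreteScheduler,
        'k_lms': diffusers.LMSDiscreteScheduler,
        'plms': diffusers.PNDMScheduler,
    }

    def set_scheduler(pipeline, sampler_name):
        cls = SCHEDULER_MAP.get(sampler_name)
        if cls is None:
            # unknown name: keep the pipeline's current scheduler
            return pipeline.scheduler
        # build the replacement scheduler from the existing scheduler's config
        return cls.from_config(pipeline.scheduler.config)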
|
||||
|
||||
def _load_img(self, img)->Image:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
|
@ -2,11 +2,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import shlex
|
||||
import copy
|
||||
import warnings
|
||||
import time
|
||||
import traceback
|
||||
import yaml
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.generate import Generate
|
||||
@ -16,9 +12,9 @@ from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_f
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
from ldm.invoke.image_util import make_grid
|
||||
from ldm.invoke.log import write_log
|
||||
from ldm.invoke.concepts_lib import Concepts
|
||||
from omegaconf import OmegaConf
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
from pathlib import Path
|
||||
from argparse import Namespace
|
||||
import pyparsing
|
||||
import ldm.invoke
|
||||
|
||||
@ -45,14 +41,20 @@ def main():
|
||||
print('--max_loaded_models must be >= 1; using 1')
|
||||
args.max_loaded_models = 1
|
||||
|
||||
# alert - setting a global here
|
||||
Globals.try_patchmatch = args.patchmatch
|
||||
Globals.always_use_cpu = args.always_use_cpu
|
||||
Globals.internet_available = args.internet_available and check_internet()
|
||||
print(f'>> Internet connectivity is {Globals.internet_available}')
|
||||
|
||||
if not args.conf:
|
||||
if not os.path.exists(os.path.join(Globals.root,'configs','models.yaml')):
|
||||
print(f"\n** Error. The file {os.path.join(Globals.root,'configs','models.yaml')} could not be found.")
|
||||
print(f'** Please check the location of your invokeai directory and use the --root_dir option to point to the correct path.')
|
||||
print(f'** This script will now exit.')
|
||||
print('** Please check the location of your invokeai directory and use the --root_dir option to point to the correct path.')
|
||||
print('** This script will now exit.')
|
||||
sys.exit(-1)
|
||||
|
||||
print(f'>> {ldm.invoke.__app_name__} {ldm.invoke.__version__}')
|
||||
print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}')
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# loading here to avoid long delays on startup
|
||||
@ -78,6 +80,9 @@ def main():
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
# migrate legacy models
|
||||
ModelManager.migrate_models()
|
||||
|
||||
# load the infile as a list of lines
|
||||
if opt.infile:
|
||||
try:
|
||||
@ -107,9 +112,8 @@ def main():
|
||||
safety_checker=opt.safety_checker,
|
||||
max_loaded_models=opt.max_loaded_models,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError):
|
||||
emergency_model_reconfigure(opt)
|
||||
sys.exit(-1)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(opt,e)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
@ -120,9 +124,18 @@ def main():
|
||||
# preload the model
|
||||
try:
|
||||
gen.load_model()
|
||||
except AssertionError:
|
||||
emergency_model_reconfigure(opt)
|
||||
sys.exit(-1)
|
||||
except KeyError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
report_model_error(opt, e)
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := opt.autoconvert:
|
||||
gen.model_manager.autoconvert_weights(
|
||||
conf_path=opt.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
# web server loops forever
|
||||
if opt.web or opt.gui:
|
||||
@ -138,6 +151,9 @@ def main():
|
||||
main_loop(gen, opt)
|
||||
except KeyboardInterrupt:
|
||||
print(f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}')
|
||||
except Exception:
|
||||
print(">> An error occurred:")
|
||||
traceback.print_exc()
|
||||
|
||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||
def main_loop(gen, opt):
|
||||
@ -147,14 +163,14 @@ def main_loop(gen, opt):
|
||||
doneAfterInFile = infile is not None
|
||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||
last_results = list()
|
||||
model_config = OmegaConf.load(opt.conf)
|
||||
|
||||
# The readline completer reads history from the .dream_history file located in the
|
||||
# output directory specified at the time of script launch. We do not currently support
|
||||
# changing the history file midstream when the output directory is changed.
|
||||
completer = get_completer(opt, models=list(model_config.keys()))
|
||||
completer = get_completer(opt, models=gen.model_manager.list_models())
|
||||
set_default_output_dir(opt, completer)
|
||||
add_embedding_terms(gen, completer)
|
||||
if gen.model:
|
||||
add_embedding_terms(gen, completer)
|
||||
output_cntr = completer.get_current_history_length()+1
|
||||
|
||||
# os.pathconf is not available on Windows
|
||||
@ -170,7 +186,7 @@ def main_loop(gen, opt):
|
||||
operation = 'generate'
|
||||
|
||||
try:
|
||||
command = get_next_command(infile)
|
||||
command = get_next_command(infile, gen.model_name)
|
||||
except EOFError:
|
||||
done = infile is None or doneAfterInFile
|
||||
infile = None
|
||||
@ -315,7 +331,7 @@ def main_loop(gen, opt):
|
||||
if use_prefix is not None:
|
||||
prefix = use_prefix
|
||||
postprocessed = upscaled if upscaled else operation=='postprocess'
|
||||
opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
|
||||
opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
|
||||
filename, formatted_dream_prompt = prepare_image_metadata(
|
||||
opt,
|
||||
prefix,
|
||||
@ -434,24 +450,50 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
|
||||
elif command.startswith('!switch'):
|
||||
model_name = command.replace('!switch ','',1)
|
||||
gen.set_model(model_name)
|
||||
add_embedding_terms(gen, completer)
|
||||
try:
|
||||
gen.set_model(model_name)
|
||||
add_embedding_terms(gen, completer)
|
||||
except KeyError as e:
|
||||
print(str(e))
|
||||
except Exception as e:
|
||||
report_model_error(opt,e)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!models'):
|
||||
gen.model_cache.print_models()
|
||||
gen.model_manager.print_models()
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!import'):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print('** please provide a path to a .ckpt or .vae model file')
|
||||
elif not os.path.exists(path[1]):
|
||||
print(f'** {path[1]}: file not found')
|
||||
print('** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1')
|
||||
else:
|
||||
add_weights_to_config(path[1], gen, opt, completer)
|
||||
import_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!convert'):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print('** please provide the path to a .ckpt or .safetensors model')
|
||||
elif not os.path.exists(path[1]):
|
||||
print(f'** {path[1]}: model not found')
|
||||
else:
|
||||
optimize_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
|
||||
elif command.startswith('!optimize'):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print('** please provide an installed model name')
|
||||
elif not path[1] in gen.model_manager.list_models():
|
||||
print(f'** {path[1]}: model not found')
|
||||
else:
|
||||
optimize_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
@ -460,7 +502,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
if len(path) < 2:
|
||||
print('** please provide the name of a model')
|
||||
else:
|
||||
edit_config(path[1], gen, opt, completer)
|
||||
edit_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
@ -521,121 +563,223 @@ def set_default_output_dir(opt:Args, completer:Completer):
|
||||
completer.set_default_dir(opt.outdir)
|
||||
|
||||
|
||||
def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
print(f'>> Model import in process. Please enter the values needed to configure this model:')
|
||||
print()
|
||||
def import_model(model_path:str, gen, opt, completer):
|
||||
'''
|
||||
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
|
||||
(3) a huggingface repository id
|
||||
'''
|
||||
model_name = None
|
||||
|
||||
if model_path.startswith(('http:','https:','ftp:')):
|
||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
|
||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
|
||||
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||
elif os.path.isdir(model_path):
|
||||
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||
else:
|
||||
print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')
|
||||
|
||||
new_config = {}
|
||||
new_config['weights'] = model_path
|
||||
if not model_name:
|
||||
return
|
||||
|
||||
if not _verify_load(model_name, gen):
|
||||
print('** model failed to load. Discarding configuration entry')
|
||||
gen.model_manager.del_model(model_name)
|
||||
return
|
||||
|
||||
if input('Make this the default model? [n] ') in ('y','Y'):
|
||||
gen.model_manager.set_default_model(model_name)
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
model_name = input('Short name for this model: ')
|
||||
if not re.match('^[\w._-]+$',model_name):
|
||||
print('** model name must contain only words, digits and the characters [._-] **')
|
||||
else:
|
||||
done = True
|
||||
new_config['description'] = input('Description of this model: ')
|
||||
gen.model_manager.commit(opt.conf)
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
print(f'>> {model_name} successfully installed')
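The new import_model() shown above decides how to treat its argument before delegating: URLs and local .ckpt files go through the checkpoint importer, while owner/name strings and local directories are treated as diffusers models. A small sketch of that classification step (the repo-id pattern is the one used above; the return labels are illustrative):

    import os
    import re

    def classify_model_path(model_path: str) -> str:
        if model_path.startswith(('http:', 'https:', 'ftp:')):
            return 'ckpt-url'
        if os.path.isfile(model_path) and model_path.endswith('.ckpt'):
            return 'ckpt-file'
        if re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
            return 'diffusers-repo-id'      # e.g. stabilityai/stable-diffusion-2-1
        if os.path.isdir(model_path):
            return 'diffusers-directory'
        return 'unknown'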
|
||||
|
||||
def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
|
||||
manager = gen.model_manager
|
||||
default_name = Path(path_or_repo).stem
|
||||
default_description = f'Imported model {default_name}'
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
completer,
|
||||
model_name=default_name,
|
||||
model_description=default_description
|
||||
)
|
||||
|
||||
if not manager.import_diffuser_model(
|
||||
path_or_repo,
|
||||
model_name = model_name,
|
||||
description = model_description):
|
||||
print('** model failed to import')
|
||||
return None
|
||||
if input('Make this the default model? [n] ').startswith(('y','Y')):
|
||||
manager.set_default_model(model_name)
|
||||
return model_name
|
||||
|
||||
def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
|
||||
manager = gen.model_manager
|
||||
default_name = Path(path_or_url).stem
|
||||
default_description = f'Imported model {default_name}'
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
completer,
|
||||
model_name=default_name,
|
||||
model_description=default_description
|
||||
)
|
||||
config_file = None
|
||||
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
|
||||
|
||||
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
|
||||
done = False
|
||||
while not done:
|
||||
new_config['config'] = input('Configuration file for this model: ')
|
||||
done = os.path.exists(new_config['config'])
|
||||
|
||||
done = False
|
||||
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
|
||||
while not done:
|
||||
vae = input('VAE autoencoder file for this model [None]: ')
|
||||
if os.path.exists(vae):
|
||||
new_config['vae'] = vae
|
||||
done = True
|
||||
else:
|
||||
done = len(vae)==0
|
||||
|
||||
config_file = input('Configuration file for this model: ').strip()
|
||||
done = os.path.exists(config_file)
|
||||
completer.complete_extensions(None)
|
||||
|
||||
for field in ('width','height'):
|
||||
done = False
|
||||
while not done:
|
||||
try:
|
||||
completer.linebuffer = '512'
|
||||
value = int(input(f'Default image {field}: '))
|
||||
assert value >= 64 and value <= 2048
|
||||
new_config[field] = value
|
||||
done = True
|
||||
except:
|
||||
print('** Please enter a valid integer between 64 and 2048')
|
||||
if not manager.import_ckpt_model(
|
||||
path_or_url,
|
||||
config = config_file,
|
||||
model_name = model_name,
|
||||
model_description = model_description,
|
||||
commit_to_conf = opt.conf,
|
||||
):
|
||||
print('** model failed to import')
|
||||
return None
|
||||
|
||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||
if input('Make this the default model? [n] ').startswith(('y','Y')):
|
||||
manager.set_model_default(model_name)
|
||||
return model_name
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
|
||||
completer.add_model(model_name)
|
||||
def _verify_load(model_name:str, gen)->bool:
|
||||
print('>> Verifying that new model loads...')
|
||||
current_model = gen.model_name
|
||||
if not gen.model_manager.get_model(model_name):
|
||||
return False
|
||||
do_switch = input('Keep model loaded? [y] ')
|
||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||
gen.set_model(model_name)
|
||||
else:
|
||||
print('>> Restoring previous model')
|
||||
gen.set_model(current_model)
|
||||
return True
|
||||
|
||||
def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_description:str=''):
|
||||
model_name = _get_model_name(model_manager.list_models(),completer,model_name)
|
||||
completer.set_line(model_description)
|
||||
model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
|
||||
return model_name, model_description
|
||||
|
||||
def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
manager = gen.model_manager
|
||||
ckpt_path = None
|
||||
|
||||
if (model_info := manager.model_info(model_name_or_path)):
|
||||
if 'weights' in model_info:
|
||||
ckpt_path = Path(model_info['weights'])
|
||||
model_name = model_name_or_path
|
||||
model_description = model_info['description']
|
||||
else:
|
||||
print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
|
||||
return
|
||||
elif os.path.exists(model_name_or_path):
|
||||
ckpt_path = Path(model_name_or_path)
|
||||
model_name,model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
completer,
|
||||
ckpt_path.stem,
|
||||
f'Converted model {ckpt_path.stem}'
|
||||
)
|
||||
else:
|
||||
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
|
||||
return
|
||||
|
||||
if not ckpt_path.is_absolute():
|
||||
ckpt_path = Path(Globals.root,ckpt_path)
|
||||
|
||||
diffuser_path = Path(Globals.root, 'models','optimized-ckpts',model_name)
|
||||
if diffuser_path.exists():
|
||||
print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.')
|
||||
return
|
||||
|
||||
new_config = gen.model_manager.convert_and_import(
|
||||
ckpt_path,
|
||||
diffuser_path,
|
||||
model_name=model_name,
|
||||
model_description=model_description,
|
||||
commit_to_conf=opt.conf,
|
||||
)
|
||||
if not new_config:
|
||||
return
|
||||
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
if input(f'Load optimized model {model_name}? [y] ') not in ('n','N'):
|
||||
gen.set_model(model_name)
|
||||
|
||||
response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ')
|
||||
if response.startswith(('y','Y')):
|
||||
ckpt_path.unlink(missing_ok=True)
|
||||
print(f'{ckpt_path} deleted')
|
||||
|
||||
def del_config(model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
gen.model_cache.del_model(model_name)
|
||||
gen.model_cache.commit(opt.conf)
|
||||
gen.model_manager.del_model(model_name)
|
||||
gen.model_manager.commit(opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.del_model(model_name)
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
|
||||
def edit_config(model_name:str, gen, opt, completer):
|
||||
config = gen.model_cache.config
|
||||
def edit_model(model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
# if model_name == current_model:
|
||||
# print("** Can't edit the active model. !switch to another model first. **")
|
||||
# return
|
||||
|
||||
if model_name not in config:
|
||||
manager = gen.model_manager
|
||||
if not (info := manager.model_info(model_name)):
|
||||
print(f'** Unknown model {model_name}')
|
||||
return
|
||||
|
||||
print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
|
||||
new_name = _get_model_name(manager.list_models(),completer,model_name)
|
||||
|
||||
conf = config[model_name]
|
||||
new_config = {}
|
||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
|
||||
for field in ('description', 'weights', 'vae', 'config', 'width','height'):
|
||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
||||
new_value = input(f'{field}: ')
|
||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||
completer.complete_extensions(None)
|
||||
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
|
||||
for attribute in info.keys():
|
||||
if type(info[attribute]) != str:
|
||||
continue
|
||||
if attribute == 'format':
|
||||
continue
|
||||
completer.set_line(info[attribute])
|
||||
info[attribute] = input(f'{attribute}: ') or info[attribute]
|
||||
|
||||
if new_name != model_name:
|
||||
manager.del_model(model_name)
|
||||
|
||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
|
||||
current_model = gen.model_name
|
||||
# this does the update
|
||||
manager.add_model(new_name, info, True)
|
||||
|
||||
op = 'modify' if clobber else 'import'
|
||||
print('\n>> New configuration:')
|
||||
if make_default:
|
||||
new_config['default'] = True
|
||||
print(yaml.dump({model_name:new_config}))
|
||||
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
||||
return False
|
||||
if input('Make this the default model? [n] ').startswith(('y','Y')):
|
||||
manager.set_default_model(new_name)
|
||||
manager.commit(opt.conf)
|
||||
completer.update_models(manager.list_models())
|
||||
print('>> Model successfully updated')
|
||||
|
||||
try:
|
||||
print('>> Verifying that new model loads...')
|
||||
gen.model_cache.add_model(model_name, new_config, clobber)
|
||||
assert gen.set_model(model_name) is not None, 'model failed to load'
|
||||
except AssertionError as e:
|
||||
print(f'** aborting **')
|
||||
gen.model_cache.del_model(model_name)
|
||||
return False
|
||||
def _get_model_name(existing_names,completer,default_name:str='')->str:
|
||||
done = False
|
||||
completer.set_line(default_name)
|
||||
while not done:
|
||||
model_name = input(f'Short name for this model [{default_name}]: ').strip()
|
||||
if len(model_name)==0:
|
||||
model_name = default_name
|
||||
if not re.match('^[\w._+-]+$',model_name):
|
||||
print('** model name must contain only words, digits and the characters "._+-" **')
|
||||
elif model_name != default_name and model_name in existing_names:
|
||||
print(f'** the name {model_name} is already in use. Pick another.')
|
||||
else:
|
||||
done = True
|
||||
return model_name
|
||||
|
||||
if make_default:
|
||||
print('making this default')
|
||||
gen.model_cache.set_default_model(model_name)
|
||||
|
||||
gen.model_cache.commit(conf_path)
|
||||
|
||||
do_switch = input(f'Keep model loaded? [y]')
|
||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||
pass
|
||||
else:
|
||||
gen.set_model(current_model)
|
||||
return True
|
||||
|
||||
def do_textmask(gen, opt, callback):
|
||||
image_path = opt.prompt
|
||||
@ -746,7 +890,7 @@ def prepare_image_metadata(
|
||||
except KeyError as e:
|
||||
print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
||||
filename = f'{prefix}.{seed}.png'
|
||||
except IndexError as e:
|
||||
except IndexError:
|
||||
print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
||||
filename = f'{prefix}.{seed}.png'
|
||||
|
||||
@ -782,9 +926,9 @@ def choose_postprocess_name(opt,prefix,seed) -> str:
|
||||
counter += 1
|
||||
return filename
|
||||
|
||||
def get_next_command(infile=None) -> str: # command string
|
||||
def get_next_command(infile=None, model_name='no model') -> str: # command string
|
||||
if infile is None:
|
||||
command = input('invoke> ')
|
||||
command = input(f'({model_name}) invoke> ').strip()
|
||||
else:
|
||||
command = infile.readline()
|
||||
if not command:
|
||||
@ -815,7 +959,8 @@ def add_embedding_terms(gen,completer):
|
||||
Called after setting the model, updates the autocompleter with
|
||||
any terms loaded by the embedding manager.
|
||||
'''
|
||||
completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
|
||||
trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings()
|
||||
completer.add_embedding_terms(trigger_strings)
|
||||
|
||||
def split_variations(variations_string) -> list:
|
||||
# shotgun parsing, woo
|
||||
@ -938,13 +1083,13 @@ def write_commands(opt, file_path:str, outfilepath:str):
|
||||
f.write('\n'.join(commands))
|
||||
print(f'>> File {outfilepath} with commands created')
|
||||
|
||||
def emergency_model_reconfigure(opt):
|
||||
print()
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
print(' You appear to have a missing or misconfigured model file(s). ')
|
||||
print(' The script will now exit and run configure_invokeai.py to help fix the problem.')
|
||||
print(' After reconfiguration is done, please relaunch invoke.py. ')
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
def report_model_error(opt:Namespace, e:Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
|
||||
response = input('Do you want to run configure_invokeai.py to select and/or reinstall models? [y] ')
|
||||
if response.startswith(('n','N')):
|
||||
return
|
||||
|
||||
print('configure_invokeai is launching....\n')
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
@ -952,7 +1097,7 @@ def emergency_model_reconfigure(opt):
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
|
||||
|
||||
previous_args = sys.argv
|
||||
sys.argv = [ 'configure_invokeai' ]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config)
|
||||
@ -961,3 +1106,20 @@ def emergency_model_reconfigure(opt):
|
||||
|
||||
import configure_invokeai
|
||||
configure_invokeai.main()
|
||||
print('** InvokeAI will now restart')
|
||||
sys.argv = previous_args
|
||||
main() # would rather do a os.exec(), but doesn't exist?
|
||||
sys.exit(0)
|
||||
|
||||
def check_internet()->bool:
|
||||
'''
|
||||
Return true if the internet is reachable.
|
||||
It does this by pinging huggingface.co.
|
||||
'''
|
||||
import urllib.request
|
||||
host = 'http://huggingface.co'
|
||||
try:
|
||||
urllib.request.urlopen(host,timeout=1)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
@ -1,3 +1,4 @@
from ._version import __version__

__app_id__= 'invoke-ai/InvokeAI'
__app_name__= 'InvokeAI'
__version__='2.2.5'

1 ldm/invoke/_version.py Normal file
@ -0,0 +1 @@
__version__='2.3.0+a0'
@ -81,22 +81,23 @@ with metadata_from_png():
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from argparse import Namespace, RawTextHelpFormatter
|
||||
import pydoc
|
||||
import json
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import shlex
|
||||
import copy
|
||||
import base64
|
||||
import copy
|
||||
import functools
|
||||
import warnings
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import pydoc
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import ldm.invoke
|
||||
import ldm.invoke.pngwriter
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
APP_ID = ldm.invoke.__app_id__
|
||||
APP_NAME = ldm.invoke.__app_name__
|
||||
@ -113,6 +114,8 @@ SAMPLER_CHOICES = [
|
||||
'k_heun',
|
||||
'k_lms',
|
||||
'plms',
|
||||
# diffusers:
|
||||
"pndm",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
@ -181,7 +184,7 @@ class Args(object):
|
||||
sys.exit(0)
|
||||
|
||||
print('* Initializing, be patient...')
|
||||
Globals.root = os.path.abspath(switches.root_dir or Globals.root)
|
||||
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
||||
Globals.try_patchmatch = switches.patchmatch
|
||||
|
||||
# now use root directory to find the init file
|
||||
@ -273,7 +276,7 @@ class Args(object):
|
||||
switches.append(f'-I {a["init_img"]}')
|
||||
switches.append(f'-A {a["sampler_name"]}')
|
||||
if a['fit']:
|
||||
switches.append(f'--fit')
|
||||
switches.append('--fit')
|
||||
if a['init_mask'] and len(a['init_mask'])>0:
|
||||
switches.append(f'-M {a["init_mask"]}')
|
||||
if a['init_color'] and len(a['init_color'])>0:
|
||||
@@ -281,7 +284,7 @@ class Args(object):
|
||||
if a['strength'] and a['strength']>0:
|
||||
switches.append(f'-f {a["strength"]}')
|
||||
if a['inpaint_replace']:
|
||||
switches.append(f'--inpaint_replace')
|
||||
switches.append('--inpaint_replace')
|
||||
if a['text_mask']:
|
||||
switches.append(f'-tm {" ".join([str(u) for u in a["text_mask"]])}')
|
||||
else:
|
||||
@@ -479,6 +482,12 @@ class Args(object):
|
||||
action='store_true',
|
||||
help='Force free gpu memory before final decoding',
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--always_use_cpu",
|
||||
dest="always_use_cpu",
|
||||
action="store_true",
|
||||
help="Force use of CPU even if GPU is available"
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--precision',
|
||||
dest='precision',
|
||||
@@ -489,13 +498,26 @@ class Args(object):
|
||||
default='auto',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--nsfw_checker'
|
||||
'--internet',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='internet_available',
|
||||
default=True,
|
||||
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--nsfw_checker',
|
||||
'--safety_checker',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='safety_checker',
|
||||
default=False,
|
||||
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoconvert',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt weights files at startup and import as optimized diffuser models',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--patchmatch',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
@@ -718,11 +740,15 @@ class Args(object):
|
||||
!NN retrieves the NNth command from the history
|
||||
|
||||
*Model manipulation*
|
||||
!models -- list models in configs/models.yaml
|
||||
!switch <model_name> -- switch to model named <model_name>
|
||||
!import_model path/to/weights/file.ckpt -- adds a model to your config
|
||||
!edit_model <model_name> -- edit a model's description
|
||||
!del_model <model_name> -- delete a model
|
||||
!models -- list models in configs/models.yaml
|
||||
!switch <model_name> -- switch to model named <model_name>
|
||||
!import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config
|
||||
!import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config
|
||||
!import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config
|
||||
!optimize_model <model_name> -- converts a .ckpt model to a diffusers model
|
||||
!convert_model /path/to/weights/file.ckpt -- converts a .ckpt file path to a diffusers model
|
||||
!edit_model <model_name> -- edit a model's description
|
||||
!del_model <model_name> -- delete a model
|
||||
"""
|
||||
)
|
||||
render_group = parser.add_argument_group('General rendering')
|
||||
@@ -1061,7 +1087,7 @@ class Args(object):
|
||||
return parser
|
||||
|
||||
def format_metadata(**kwargs):
|
||||
print(f'format_metadata() is deprecated. Please use metadata_dumps()')
|
||||
print('format_metadata() is deprecated. Please use metadata_dumps()')
|
||||
return metadata_dumps(kwargs)
|
||||
|
||||
def metadata_dumps(opt,
|
||||
@@ -1128,7 +1154,7 @@ def metadata_dumps(opt,
|
||||
rfc_dict.pop('strength')
|
||||
|
||||
if len(seeds)==0 and opt.seed:
|
||||
seeds=[seed]
|
||||
seeds=[opt.seed]
|
||||
|
||||
if opt.grid:
|
||||
images = []
|
||||
@@ -1199,7 +1225,7 @@ def metadata_loads(metadata) -> list:
|
||||
opt = Args()
|
||||
opt._cmd_switches = Namespace(**image)
|
||||
results.append(opt)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
import sys, traceback
|
||||
print('>> could not read metadata',file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
4
ldm/invoke/ckpt_generator/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for the ldm.invoke.ckpt_generator package
|
||||
'''
|
||||
from .base import CkptGenerator
|
338
ldm/invoke/ckpt_generator/base.py
Normal file
@@ -0,0 +1,338 @@
|
||||
'''
|
||||
Base class for ldm.invoke.ckpt_generator.*
|
||||
including img2img, txt2img, and inpaint
|
||||
|
||||
THESE MODULES ARE TRANSITIONAL AND WILL BE REMOVED AT A FUTURE DATE
|
||||
WHEN LEGACY CKPT MODEL SUPPORT IS DISCONTINUED.
|
||||
'''
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
import os.path as osp
|
||||
import traceback
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image, ImageFilter, ImageChops
|
||||
import cv2 as cv
|
||||
from einops import rearrange, repeat
|
||||
from pytorch_lightning import seed_everything
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
downsampling = 8
|
||||
CAUTION_IMG = 'assets/caution.png'
|
||||
|
||||
class CkptGenerator():
|
||||
def __init__(self, model, precision):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
self.free_gpu_mem = None
|
||||
self.caution_img = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
|
||||
|
||||
def set_variation(self, seed, variation_amount, with_variations):
|
||||
self.seed = seed
|
||||
self.variation_amount = variation_amount
|
||||
self.with_variations = with_variations
|
||||
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
attention_maps_callback = None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler = sampler,
|
||||
init_image = init_image,
|
||||
width = width,
|
||||
height = height,
|
||||
step_callback = step_callback,
|
||||
threshold = threshold,
|
||||
perlin = perlin,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed is not None and seed >= 0 else self.new_seed()
|
||||
first_seed = seed
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
|
||||
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||
# the inpaint-1.5 model. Not sure what it did.... ?
|
||||
with scope(self.model.device.type):
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
x_T = None
|
||||
if self.variation_amount > 0:
|
||||
seed_everything(seed)
|
||||
target_noise = self.get_noise(width,height)
|
||||
x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
|
||||
elif initial_noise is not None:
|
||||
# i.e. we specified particular variations
|
||||
x_T = initial_noise
|
||||
else:
|
||||
seed_everything(seed)
|
||||
try:
|
||||
x_T = self.get_noise(width,height)
|
||||
except:
|
||||
print('** An error occurred while getting initial noise **')
|
||||
print(traceback.format_exc())
|
||||
|
||||
image = make_image(x_T)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_check(image)
|
||||
|
||||
results.append([image, seed])
|
||||
|
||||
if image_callback is not None:
|
||||
attention_maps_image = None if len(attention_maps_images)==0 else attention_maps_images[-1]
|
||||
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image)
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
return results
|
||||
|
||||
def sample_to_image(self,samples)->Image.Image:
|
||||
"""
|
||||
Given samples returned from a sampler, converts
|
||||
it into a PIL Image
|
||||
"""
|
||||
x_samples = self.model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
if len(x_samples) != 1:
|
||||
raise Exception(
|
||||
f'>> expected to get a single image, but got {len(x_samples)}')
|
||||
x_sample = 255.0 * rearrange(
|
||||
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
||||
)
|
||||
return Image.fromarray(x_sample.astype(np.uint8))
|
||||
|
||||
# write an approximate RGB image from latent samples for a single step to PNG
|
||||
|
||||
def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image:
|
||||
if init_image is None or init_mask is None:
|
||||
return result
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L')
|
||||
pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
|
||||
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
|
||||
# Get numpy version of result
|
||||
np_image = np.asarray(result, dtype=np.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
if np_init_rgb_pixels_masked.size > 0:
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
else:
|
||||
matched_result = Image.fromarray(np_image, mode='RGB')
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if mask_blur_radius > 0:
|
||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||
pmd = Image.fromarray(nmd, mode='L')
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
|
||||
|
||||
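The color correction above matches the generated pixels' per-channel mean and standard deviation to those of the original image inside the mask. A minimal numpy sketch of just that statistics-matching step, without the mask handling (array names are illustrative):

import numpy as np

def match_channel_stats(generated: np.ndarray, reference: np.ndarray) -> np.ndarray:
    # Both inputs are HxWx3 uint8 arrays; each RGB channel of `generated` is
    # shifted and scaled so its mean/std match `reference`.
    gen = generated.astype(np.float32)
    ref = reference.astype(np.float32)
    gen_mean, gen_std = gen.reshape(-1, 3).mean(0), gen.reshape(-1, 3).std(0)
    ref_mean, ref_std = ref.reshape(-1, 3).mean(0), ref.reshape(-1, 3).std(0)
    matched = (gen - gen_mean) / np.maximum(gen_std, 1e-6) * ref_std + ref_mean
    return matched.clip(0, 255).astype(np.uint8)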
def sample_to_lowres_estimated_image(self,samples):
|
||||
# originally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor([
|
||||
# R G B
|
||||
[ 0.3444, 0.1385, 0.0670], # L1
|
||||
[ 0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445] # L4
|
||||
], dtype=samples.dtype, device=samples.device)
|
||||
|
||||
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||
latents_ubyte = (((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
|
||||
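For reference, the preview above is just a 4-to-3 channel matrix multiply followed by rescaling from roughly [-1, 1] to 0..255. A tiny standalone sketch with a random latent (the 64x64 size is only an example, corresponding to a 512x512 image):

import torch

latent = torch.randn(4, 64, 64)                       # 4-channel latent at 1/8 resolution
rgb_factors = torch.tensor([                          # latent channel -> (R, G, B)
    [ 0.3444,  0.1385,  0.0670],
    [ 0.1247,  0.4027,  0.1494],
    [-0.3192,  0.2513,  0.2103],
    [-0.1307, -0.1874, -0.7445],
])
preview = latent.permute(1, 2, 0) @ rgb_factors       # HxWx3, roughly in [-1, 1]
preview_u8 = ((preview + 1) / 2).clamp(0, 1).mul(255).byte()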
def generate_initial_noise(self, seed, width, height):
|
||||
initial_noise = None
|
||||
if self.variation_amount > 0 or len(self.with_variations) > 0:
|
||||
# use fixed initial noise plus random noise per iteration
|
||||
seed_everything(seed)
|
||||
initial_noise = self.get_noise(width,height)
|
||||
for v_seed, v_weight in self.with_variations:
|
||||
seed = v_seed
|
||||
seed_everything(seed)
|
||||
next_noise = self.get_noise(width,height)
|
||||
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
|
||||
if self.variation_amount > 0:
|
||||
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
||||
seed = random.randrange(0,np.iinfo(np.uint32).max)
|
||||
return (seed, initial_noise)
|
||||
else:
|
||||
return (seed, None)
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
"""
|
||||
Returns a tensor filled with random numbers, either from a normal distribution
|
||||
(txt2img) or from the latent image (img2img, inpaint)
|
||||
"""
|
||||
raise NotImplementedError("get_noise() must be implemented in a descendent class")
|
||||
|
||||
def get_perlin_noise(self,width,height):
|
||||
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
|
||||
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
'''
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
collinear. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
'''
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(self.model.device)
|
||||
|
||||
return v2
|
||||
|
||||
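slerp interpolates along the arc between the two noise tensors (falling back to a plain lerp when they are nearly collinear), which keeps intermediate noise at a sensible norm. A quick numpy check of the same formula at its endpoints:

import numpy as np

def slerp_np(t, v0, v1, dot_threshold=0.9995):
    dot = np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    if np.abs(dot) > dot_threshold:              # nearly collinear: plain lerp
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1

rng = np.random.default_rng(0)
a, b = rng.standard_normal(16), rng.standard_normal(16)
assert np.allclose(slerp_np(0.0, a, b), a)       # t=0 returns the first vector
assert np.allclose(slerp_np(1.0, a, b), b)       # t=1 returns the second vector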
def safety_check(self,image:Image.Image):
|
||||
'''
|
||||
If the CompVis safety checker flags an NSFW image, we
|
||||
blur it out.
|
||||
'''
|
||||
import diffusers
|
||||
|
||||
checker = self.safety_checker['checker']
|
||||
extractor = self.safety_checker['extractor']
|
||||
features = extractor([image], return_tensors="pt")
|
||||
features.to(self.model.device)
|
||||
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||
if has_nsfw_concept[0]:
|
||||
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||
return self.blur(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
def blur(self,input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = self.get_caution_img()
|
||||
if caution:
|
||||
blurry.paste(caution,(0,0),caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
def get_caution_img(self):
|
||||
path = None
|
||||
if self.caution_img:
|
||||
return self.caution_img
|
||||
# Find the caution image. If we are installed in the package directory it will
|
||||
# be six levels up. If we are in the repo directory it will be three levels up.
|
||||
for dots in ('../../..','../../../../../..'):
|
||||
caution_path = osp.join(osp.dirname(__file__),dots,CAUTION_IMG)
|
||||
if osp.exists(caution_path):
|
||||
path = caution_path
|
||||
break
|
||||
if not path:
|
||||
return
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height //2))
|
||||
return self.caution_img
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or '.'
|
||||
if not os.path.exists(dirname):
|
||||
print(f'** creating directory {dirname}')
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath,'PNG')
|
||||
|
||||
|
501
ldm/invoke/ckpt_generator/embiggen.py
Normal file
@@ -0,0 +1,501 @@
|
||||
'''
|
||||
ldm.invoke.ckpt_generator.embiggen descends from ldm.invoke.ckpt_generator
|
||||
and generates with ldm.invoke.ckpt_generator.img2img
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from PIL import Image
|
||||
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||
from ldm.invoke.ckpt_generator.img2img import CkptImg2Img
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class CkptEmbiggen(CkptGenerator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normally
|
||||
def generate(self,prompt,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
step_callback = step_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
with scope(self.model.device.type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, prompt_in=prompt)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_img,
|
||||
strength,
|
||||
width,
|
||||
height,
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen is None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
print(
|
||||
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
|
||||
|
||||
# Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.')
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = CkptImg2Img(self.model,self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
initsuperimage = Image.open(init_img)
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert('RGB')
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||
initsuperheight = round(initsuperheight*embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set to zero
|
||||
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
4, # upscale scale
|
||||
)
|
||||
else:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
2, # upscale scale
|
||||
)
|
||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||
# but from personal experience it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
if embiggen[2] < 1:
|
||||
if embiggen[2] < 0:
|
||||
embiggen[2] = 0
|
||||
overlap_size_x = round(embiggen[2] * width)
|
||||
overlap_size_y = round(embiggen[2] * height)
|
||||
else:
|
||||
overlap_size_x = round(embiggen[2])
|
||||
overlap_size_y = round(embiggen[2])
|
||||
|
||||
# With overall image width and height known, determine how many tiles we need
|
||||
def ceildiv(a, b):
|
||||
return -1 * (-a // b)
|
||||
|
||||
# X and Y need to be determined independently (we may have savings on one based on the buffer pixel count)
|
||||
# (initsuperwidth - width) is the area remaining to the right that we need to layer tiles to fill
|
||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width,
|
||||
width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height,
|
||||
height - overlap_size_y) + 1
|
||||
# Sanity
|
||||
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
|
||||
|
||||
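As a concrete example of the tile arithmetic above (numbers are hypothetical): a 1280x1280 super-image cut into 512x512 tiles with a 0.25 overlap ratio needs 3 tiles per axis.

def ceildiv(a, b):
    return -1 * (-a // b)

initsuperwidth = initsuperheight = 1280   # upscaled init image
width = height = 512                      # tile size
overlap = round(0.25 * 512)               # 128 px shared between neighbouring tiles
tiles_x = ceildiv(initsuperwidth - width, width - overlap) + 1     # -> 3
tiles_y = ceildiv(initsuperheight - height, height - overlap) + 1  # -> 3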
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = Image.linear_gradient('L').rotate(
|
||||
90).resize((overlap_size_x, height))
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new('L', (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
# Place the pixel as the inverse of the distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "trailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new('L', (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x-(255-y)) * (255 / max(1,y)))
|
||||
# Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
|
||||
|
||||
# Create alpha layers default fully white
|
||||
alphaLayerL = Image.new("L", (width, height), 255)
|
||||
alphaLayerT = Image.new("L", (width, height), 255)
|
||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||
# Paste gradients into alpha layers
|
||||
alphaLayerL.paste(agradientL, (0, 0))
|
||||
alphaLayerT.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||
alphaLayerTaC = alphaLayerT.copy()
|
||||
alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||
alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
if embiggen_tiles:
|
||||
# Individual unconnected sides
|
||||
alphaLayerR = Image.new("L", (width, height), 255)
|
||||
alphaLayerR.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerB = Image.new("L", (width, height), 255)
|
||||
alphaLayerB.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||
alphaLayerTB.paste(agradientT, (0, 0))
|
||||
alphaLayerTB.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||
alphaLayerLR.paste(agradientL, (0, 0))
|
||||
alphaLayerLR.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
|
||||
# Sides and corner Layers
|
||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRBC.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerRBC.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerRBC.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||
alphaLayerLBC.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerLBC.paste(agradientC.rotate(90).resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
|
||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRTC.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||
alphaLayerRTC.paste(agradientC.rotate(270).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
# All but X layers
|
||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABT.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerABT.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABL.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerABL.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABR.paste(agradientT, (0, 0))
|
||||
alphaLayerABR.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABB.paste(agradientL, (0, 0))
|
||||
alphaLayerABB.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
|
||||
# All-around layer
|
||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||
alphaLayerAA.paste(agradientT, (0, 0))
|
||||
alphaLayerAA.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
alphaLayerAA.paste(agradientC.rotate(270).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
# Clean up temporary gradients
|
||||
del agradientL
|
||||
del agradientT
|
||||
del agradientC
|
||||
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
|
||||
else:
|
||||
print(
|
||||
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
|
||||
|
||||
emb_tile_store = []
|
||||
# Although we could use the same seed for every tile for determinism, at higher strengths this may
|
||||
# produce duplicated structures for each tile and make the tiling effect more obvious
|
||||
# instead track and iterate a local seed we pass to Img2Img
|
||||
seed = self.seed
|
||||
seedintlimit = np.iinfo(np.uint32).max - 1 # only retrieve this one from numpy
|
||||
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
# Don't iterate on first tile
|
||||
if tile != 0:
|
||||
if seed < seedintlimit:
|
||||
seed += 1
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
# Determine if this is a re-run and replace
|
||||
if embiggen_tiles and tile not in embiggen_tiles:
|
||||
continue
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
# Determine bounds to cut up the init image
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
|
||||
# Cropped image of above dimension (does not modify the original)
|
||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||
# DEBUG:
|
||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
print(
|
||||
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
|
||||
else:
|
||||
print(
|
||||
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(
|
||||
newinitimage).astype(np.float32) / 255.0
|
||||
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
||||
newinitimage = torch.from_numpy(newinitimage)
|
||||
newinitimage = 2.0 * newinitimage - 1.0
|
||||
newinitimage = newinitimage.to(self.model.device)
|
||||
|
||||
tile_results = gen_img2img.generate(
|
||||
prompt,
|
||||
iterations = 1,
|
||||
seed = seed,
|
||||
sampler = DDIMSampler(self.model, device=self.model.device),
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
conditioning = conditioning,
|
||||
ddim_eta = ddim_eta,
|
||||
image_callback = None, # called only after the final image is generated
|
||||
step_callback = step_callback, # called after each intermediate image is generated
|
||||
width = width,
|
||||
height = height,
|
||||
init_image = newinitimage, # notice that init_image is different from init_img
|
||||
mask_image = None,
|
||||
strength = strength,
|
||||
)
|
||||
|
||||
emb_tile_store.append(tile_results[0][0])
|
||||
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
||||
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
||||
del newinitimage
|
||||
|
||||
# Sanity check we have them all
|
||||
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
|
||||
outputsuperimage = Image.new(
|
||||
"RGBA", (initsuperwidth, initsuperheight))
|
||||
if embiggen_tiles:
|
||||
outputsuperimage.alpha_composite(
|
||||
initsuperimage.convert('RGBA'), (0, 0))
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
if embiggen_tiles:
|
||||
if tile in embiggen_tiles:
|
||||
intileimage = emb_tile_store.pop(0)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
intileimage = emb_tile_store[tile]
|
||||
intileimage = intileimage.convert('RGBA')
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
|
||||
left = 0
|
||||
top = 0
|
||||
else:
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i *
|
||||
(width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
# Handle gradients for various conditions
|
||||
# Handle emb_rerun case
|
||||
if embiggen_tiles:
|
||||
# top of image
|
||||
if emb_row_i == 0:
|
||||
if emb_column_i == 0:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerB)
|
||||
# Otherwise do nothing on this tile
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRBC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
else:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerLR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABT)
|
||||
# bottom of image
|
||||
elif emb_row_i == emb_tiles_y - 1:
|
||||
if emb_column_i == 0:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
# No tiles to look ahead to
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
# vertical middle of image
|
||||
else:
|
||||
if emb_column_i == 0:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTB)
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABL)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
else:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerAA)
|
||||
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
|
||||
else:
|
||||
if emb_row_i == 0 and emb_column_i >= 1:
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
elif emb_row_i >= 1 and emb_column_i == 0:
|
||||
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerT)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
print('Error: could not find all Embiggen output tiles in memory. Something must have gone wrong with img2img generation.')
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
return outputsuperimage
|
||||
# end of function declaration
|
||||
return make_image
|
97
ldm/invoke/ckpt_generator/img2img.py
Normal file
@@ -0,0 +1,97 @@
|
||||
'''
|
||||
ldm.invoke.ckpt_generator.img2img descends from ldm.invoke.ckpt_generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import PIL
|
||||
from torch import Tensor
|
||||
from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
class CkptImg2Img(CkptGenerator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # by get_noise()
|
||||
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = self._image_to_tensor(init_image.convert('RGB'))
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc - 1]).to(self.model.device),
|
||||
noise=x_T
|
||||
)
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
all_timesteps_count = steps
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
init_latent = self.init_latent
|
||||
assert init_latent is not None,'call to get_noise() when init_latent not set'
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(init_latent, device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn_like(init_latent, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = init_latent.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
|
||||
|
||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
if len(image.shape) == 2: # 'L' image, as in a mask
|
||||
image = image[None,None]
|
||||
else: # 'RGB' image
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
if normalize:
|
||||
image = 2.0 * image - 1.0
|
||||
return image.to(self.model.device)
|
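A quick shape check of the _image_to_tensor conversion above with a dummy PIL image (no model required); the tensor comes out in BCHW layout, normalized to [-1, 1]:

import numpy as np
import torch
from PIL import Image

img = Image.new('RGB', (64, 48), color=(128, 64, 32))
arr = np.array(img).astype(np.float32) / 255.0               # 48x64x3 in [0, 1]
tensor = torch.from_numpy(arr[None].transpose(0, 3, 1, 2))   # 1x3x48x64 (BCHW)
tensor = 2.0 * tensor - 1.0                                   # scale to [-1, 1]
assert tensor.shape == (1, 3, 48, 64)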
358
ldm/invoke/ckpt_generator/inpaint.py
Normal file
@@ -0,0 +1,358 @@
|
||||
'''
|
||||
ldm.invoke.ckpt_generator.inpaint descends from ldm.invoke.ckpt_generator
|
||||
'''
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import PIL
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from einops import rearrange, repeat
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.ckpt_generator.img2img import CkptImg2Img
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
from ldm.util import debug_image
|
||||
from ldm.invoke.patchmatch import PatchMatch
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
def infill_methods()->list[str]:
|
||||
methods = list()
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.append('patchmatch')
|
||||
methods.append('tile')
|
||||
return methods
|
||||
|
||||
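infill_methods() lists 'patchmatch' first only when the optional PatchMatch extension is importable, so its first entry can be treated as the preferred default. A minimal usage sketch (`requested` is a hypothetical user setting):

available = infill_methods()              # e.g. ['patchmatch', 'tile'] or ['tile']
requested = 'patchmatch'
method = requested if requested in available else available[0]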
class CkptInpaint(CkptImg2Img):
|
||||
def __init__(self, model, precision):
|
||||
self.init_latent = None
|
||||
self.pil_image = None
|
||||
self.pil_mask = None
|
||||
self.mask_blur_radius = 0
|
||||
self.infill_method = None
|
||||
super().__init__(model, precision)
|
||||
|
||||
# Outpaint support code
|
||||
def get_tile_images(self, image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False
|
||||
)
|
||||
|
||||
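get_tile_images() uses numpy's as_strided trick to view an HxWxC image as a grid of non-overlapping tiles without copying. A small self-contained check of that stride math:

import numpy as np

img = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)   # 4x6 RGB toy image
h = w = 2
s0, s1, s2 = img.strides
tiles = np.lib.stride_tricks.as_strided(
    np.ravel(img),
    shape=(4 // h, 6 // w, h, w, 3),
    strides=(h * s0, w * s1, s0, s1, s2),
    writeable=False,
)
assert tiles.shape == (2, 3, 2, 2, 3)
assert (tiles[1, 2] == img[2:4, 4:6]).all()   # tile (row 1, col 2) is the bottom-right block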
def infill_patchmatch(self, im: Image.Image) -> Image:
|
||||
if im.mode != 'RGBA':
|
||||
return im
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != 'RGBA':
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a,*tile_size).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:,:,:,:,3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = (tiles_mask > 0)
|
||||
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed = seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1,2)
|
||||
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
|
||||
si = Image.fromarray(st, mode='RGBA')
|
||||
|
||||
return si
|
||||
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||
|
||||
# Detect hard edges
|
||||
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
|
||||
|
||||
# Combine
|
||||
npmask = npgradient + npedge
|
||||
|
||||
# Expand
|
||||
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
if edge_blur > 0:
|
||||
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
|
||||
|
||||
return ImageOps.invert(new_mask)
|
||||
|
||||
|
||||
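The core of mask_edge() is: find the hard boundary of the mask, widen it by dilation, soften it with a box blur, and invert, leaving a dark band along the seam. A compact sketch of that pipeline on a synthetic half-opaque mask (the gradient term for partially transparent pixels is omitted here):

import numpy as np
import cv2 as cv
from PIL import Image, ImageFilter, ImageOps

hard = np.zeros((64, 64), dtype=np.uint8)     # synthetic mask: left half opaque
hard[:, :32] = 255

edges = cv.Canny(hard, threshold1=100, threshold2=200)            # hard boundary
band = cv.dilate(edges, np.ones((3, 3), np.uint8), iterations=4)  # widen the seam
seam_mask = ImageOps.invert(Image.fromarray(band).filter(ImageFilter.BoxBlur(2)))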
def seam_paint(self,
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,strength,
|
||||
noise,
|
||||
step_callback
|
||||
) -> Image.Image:
|
||||
hard_mask = self.pil_image.split()[-1].copy()
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image = im.copy().convert('RGBA'),
|
||||
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
|
||||
strength = strength,
|
||||
mask_blur_radius = 0,
|
||||
seam_size = 0,
|
||||
step_callback = step_callback,
|
||||
inpaint_width = im.width,
|
||||
inpaint_height = im.height
|
||||
)
|
||||
|
||||
seam_noise = self.get_noise(im.width, im.height)
|
||||
|
||||
result = make_image(seam_noise)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False, enable_image_debugging=False,
|
||||
infill_method = None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||
"""
|
||||
|
||||
self.enable_image_debugging = enable_image_debugging
|
||||
self.infill_method = infill_method or infill_methods()[0] # The infill method to use
|
||||
|
||||
self.inpaint_width = inpaint_width
|
||||
self.inpaint_height = inpaint_height
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
self.pil_image = init_image.copy()
|
||||
|
||||
# Do infill
|
||||
if infill_method == 'patchmatch' and PatchMatch.patchmatch_available():
|
||||
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
||||
else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch
|
||||
init_filled = self.tile_fill_missing(
|
||||
self.pil_image.copy(),
|
||||
seed = self.seed,
|
||||
tile_size = tile_size
|
||||
)
|
||||
init_filled.paste(init_image, (0,0), init_image.split()[-1])
|
||||
|
||||
# Resize if requested for inpainting
|
||||
if inpaint_width and inpaint_height:
|
||||
init_filled = init_filled.resize((inpaint_width, inpaint_height))
|
||||
|
||||
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
||||
|
||||
# Create init tensor
|
||||
init_image = self._image_to_tensor(init_filled.convert('RGB'))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image.copy()
|
||||
debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||
|
||||
mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
|
||||
self.pil_mask = mask_image
|
||||
|
||||
# Resize if requested for inpainting
|
||||
if inpaint_width and inpaint_height:
|
||||
mask_image = mask_image.resize((inpaint_width, inpaint_height))
|
||||
|
||||
debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||
mask_image = mask_image.resize(
|
||||
(
|
||||
mask_image.width // downsampling,
|
||||
mask_image.height // downsampling
|
||||
),
|
||||
resample=Image.Resampling.NEAREST
|
||||
)
|
||||
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
# klms samplers not supported yet, so ignore previous sampler
|
||||
if isinstance(sampler,KSampler):
|
||||
print(
|
||||
f">> Using recommended DDIM sampler for inpainting."
|
||||
)
|
||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
|
||||
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
|
||||
print(f">> target t_enc is {t_enc} steps")
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc - 1]).to(self.model.device),
|
||||
noise=x_T
|
||||
)
|
||||
|
||||
# to replace masked area with latent noise, weighted by inpaint_replace strength
|
||||
if inpaint_replace > 0.0:
|
||||
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
|
||||
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
|
||||
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
|
||||
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
|
||||
z_enc = z_enc * mask_image + masked_region
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
mask = mask_image,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
|
||||
result = self.sample_to_image(samples)
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
if seam_size > 0:
|
||||
old_image = self.pil_image or init_image
|
||||
old_mask = self.pil_mask or mask_image
|
||||
|
||||
result = self.seam_paint(
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
seam_strength,
|
||||
x_T,
|
||||
step_callback)
|
||||
|
||||
# Restore original settings
|
||||
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,
|
||||
old_image,
|
||||
old_mask,
|
||||
strength,
|
||||
mask_blur_radius, seam_size, seam_blur, seam_strength,
|
||||
seam_steps, tile_size, step_callback,
|
||||
inpaint_replace, enable_image_debugging,
|
||||
inpaint_width = inpaint_width,
|
||||
inpaint_height = inpaint_height,
|
||||
infill_method = infill_method,
|
||||
**kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
|
||||
|
||||
# Resize if necessary
|
||||
if self.inpaint_width and self.inpaint_height:
|
||||
gen_result = gen_result.resize(self.pil_image.size)
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)
|
||||
|
||||
return corrected_result
|
175
ldm/invoke/ckpt_generator/omnibus.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
from PIL import Image, ImageOps, ImageChops
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.ckpt_generator.base import downsampling
|
||||
from ldm.invoke.ckpt_generator.img2img import CkptImg2Img
|
||||
from ldm.invoke.ckpt_generator.txt2img import CkptTxt2Img
|
||||
|
||||
class CkptOmnibus(CkptImg2Img,CkptTxt2Img):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.pil_mask = None
|
||||
self.pil_image = None
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
init_image = None,
|
||||
mask_image = None,
|
||||
strength = None,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
perlin=0.0,
|
||||
mask_blur_radius: int = 8,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
num_samples = 1
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, Image.Image):
|
||||
self.pil_image = init_image
|
||||
if init_image.mode != 'RGB':
|
||||
init_image = init_image.convert('RGB')
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
if isinstance(mask_image, Image.Image):
|
||||
self.pil_mask = mask_image
|
||||
|
||||
mask_image = ImageChops.multiply(mask_image.convert('L'), self.pil_image.split()[-1])
|
||||
mask_image = self._image_to_tensor(ImageOps.invert(mask_image), normalize=False)
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
t_enc = steps
|
||||
|
||||
if init_image is not None and mask_image is not None: # inpainting
|
||||
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||
|
||||
elif init_image is not None: # img2img
|
||||
scope = choose_autocast(self.precision)
|
||||
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
# create a completely black mask (1s)
|
||||
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
|
||||
# and the masked image is just a copy of the original
|
||||
masked_image = init_image
|
||||
|
||||
else: # txt2img
|
||||
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
|
||||
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
|
||||
masked_image = init_image
|
||||
|
||||
self.init_latent = init_image
|
||||
height = init_image.shape[2]
|
||||
width = init_image.shape[3]
|
||||
model = self.model
|
||||
|
||||
def make_image(x_T):
|
||||
with torch.no_grad():
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
|
||||
batch = self.make_batch_sd(
|
||||
init_image,
|
||||
mask_image,
|
||||
masked_image,
|
||||
prompt=prompt,
|
||||
device=model.device,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
|
||||
c = model.cond_stage_model.encode(batch["txt"])
|
||||
c_cat = list()
|
||||
for ck in model.concat_keys:
|
||||
cc = batch[ck].float()
|
||||
if ck != model.masked_image_key:
|
||||
bchw = [num_samples, 4, height//8, width//8]
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
|
||||
# cond
|
||||
cond={"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
|
||||
# uncond cond
|
||||
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||
shape = [model.channels, height//8, width//8]
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = cond,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc_full,
|
||||
eta = 1.0,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
)
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def make_batch_sd(
|
||||
self,
|
||||
image,
|
||||
mask,
|
||||
masked_image,
|
||||
prompt,
|
||||
device,
|
||||
num_samples=1):
|
||||
batch = {
|
||||
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"txt": num_samples * [prompt],
|
||||
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
}
|
||||
return batch
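# Usage sketch for make_batch_sd(): einops.repeat broadcasts each 1-sample
# tensor to num_samples copies, e.g.
#   repeat(torch.randn(1, 3, 512, 512), "1 ... -> n ...", n=2).shape
#   -> torch.Size([2, 3, 512, 512])
# so "image", "mask" and "masked_image" all share the same leading batch size.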
|
||||
|
||||
def get_noise(self, width:int, height:int):
|
||||
if self.init_latent is not None:
|
||||
height = self.init_latent.shape[2]
|
||||
width = self.init_latent.shape[3]
|
||||
return CkptTxt2Img.get_noise(self,width,height)
|
||||
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
if self.pil_image.size != self.pil_mask.size:
|
||||
return gen_result
|
||||
|
||||
corrected_result = super(CkptImg2Img, self).repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
|
||||
return corrected_result
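# Shape bookkeeping behind the "9-channel" inpainting UNet driven above, shown
# on dummy tensors. This assumes the checkpoint's concat_keys order is
# (mask, masked_image), as in the runwayml inpainting config; the tensors are
# stand-ins, not real model outputs.
import torch

num_samples, height, width = 1, 512, 512
latent_hw = (height // 8, width // 8)

noisy_latent = torch.randn(num_samples, 4, *latent_hw)              # x_T fed to the UNet
mask_latent = torch.nn.functional.interpolate(
    torch.ones(num_samples, 1, height, width), size=latent_hw)       # 1-channel mask, resized to latent size
masked_image_latent = torch.randn(num_samples, 4, *latent_hw)        # stand-in for the encoded masked image

c_cat = torch.cat([mask_latent, masked_image_latent], dim=1)         # 1 + 4 = 5 conditioning channels
unet_input = torch.cat([noisy_latent, c_cat], dim=1)                 # 4 + 5 = 9 channels total
assert unet_input.shape[1] == 9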
|
88 ldm/invoke/ckpt_generator/txt2img.py Normal file
@@ -0,0 +1,88 @@
|
||||
'''
|
||||
ldm.invoke.ckpt_generator.txt2img inherits from ldm.invoke.ckpt_generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
import gc
|
||||
|
||||
|
||||
class CkptTxt2Img(CkptGenerator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
|
||||
attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to('cpu')
|
||||
self.model.cond_stage_model.device = 'cpu'
|
||||
self.model.cond_stage_model.to('cpu')
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
||||
if self.perlin > 0.0:
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
return x
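# Worked example of the noise produced above, assuming the stock SD values
# latent_channels=4 and downsampling_factor=8: a 512x512 request gives a
# [1, 4, 64, 64] tensor, with any perlin fraction blended in linearly. The
# perlin tensor below is a random stand-in for get_perlin_noise().
import torch

latent_channels, downsampling_factor = 4, 8
width = height = 512
x = torch.randn(1, latent_channels,
                height // downsampling_factor,
                width // downsampling_factor)
perlin = 0.2
perlin_noise = torch.randn_like(x)            # placeholder for real perlin noise
x = (1 - perlin) * x + perlin * perlin_noise
assert x.shape == (1, 4, 64, 64)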
|
||||
|
182 ldm/invoke/ckpt_generator/txt2img2img.py Normal file
@@ -0,0 +1,182 @@
|
||||
'''
|
||||
ldm.invoke.ckpt_generator.txt2img inherits from ldm.invoke.ckpt_generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
import gc
|
||||
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||
from ldm.invoke.ckpt_generator.omnibus import CkptOmnibus
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from PIL import Image
|
||||
|
||||
class CkptTxt2Img2Img(CkptGenerator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,strength,step_callback=None,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
scale_dim = min(width, height)
|
||||
scale = 512 / scale_dim
|
||||
|
||||
init_width = math.ceil(scale * width / 64) * 64
|
||||
init_height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
init_height // self.downsampling_factor,
|
||||
init_width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
#x = self.get_noise(init_width, init_height)
|
||||
x = x_T
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
samples = torch.nn.functional.interpolate(
|
||||
samples,
|
||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||
mode="bilinear"
|
||||
)
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
ddim_sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
z_enc = ddim_sampler.stochastic_encode(
|
||||
samples,
|
||||
torch.tensor([t_enc-1]).to(self.model.device),
|
||||
noise=self.get_noise(width,height,False)
|
||||
)
|
||||
|
||||
# decode it
|
||||
samples = ddim_sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
all_timesteps_count=steps
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to('cpu')
|
||||
self.model.cond_stage_model.device = 'cpu'
|
||||
self.model.cond_stage_model.to('cpu')
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpainting model is so conservative
|
||||
# it doesn't change the image (much)
|
||||
def inpaint_make_image(x_T):
|
||||
omnibus = CkptOmnibus(self.model,self.precision)
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
width=init_width,
|
||||
height=init_height,
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
assert result is not None and len(result)>0,'** txt2img failed **'
|
||||
image = result[0][0]
|
||||
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
|
||||
print(kwargs.pop('init_image',None))
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=interpolated_image,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=result[0][1],
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
return result[0][0]
|
||||
|
||||
if sampler.uses_inpainting_model():
|
||||
return inpaint_make_image
|
||||
else:
|
||||
return make_image
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height,scale = True):
|
||||
# print(f"Get noise: {width}x{height}")
|
||||
if scale:
|
||||
trained_square = 512 * 512
|
||||
actual_square = width * height
|
||||
scale = math.sqrt(trained_square / actual_square)
|
||||
scaled_width = math.ceil(scale * width / 64) * 64
|
||||
scaled_height = math.ceil(scale * height / 64) * 64
|
||||
else:
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device=device)
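# Worked example of the scaling above: a 768x512 request draws its noise at a
# roughly 512x512-equivalent area, rounded up to multiples of 64.
import math

width, height = 768, 512
scale = math.sqrt((512 * 512) / (width * height))     # ~0.8165
scaled_width = math.ceil(scale * width / 64) * 64      # 640
scaled_height = math.ceil(scale * height / 64) * 64    # 448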
|
||||
|
953 ldm/invoke/ckpt_to_diffuser.py Normal file
@@ -0,0 +1,953 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Adapted for use as a module by Lincoln Stein <lstein@gmail.com>
|
||||
# Original file at: https://github.com/huggingface/diffusers/blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py
|
||||
""" Conversion script for the LDM checkpoints. """
|
||||
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from ldm.invoke.globals import Globals
|
||||
from safetensors.torch import load_file
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||
"""
|
||||
if n_shave_prefix_segments >= 0:
|
||||
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
||||
else:
|
||||
return ".".join(path.split(".")[:n_shave_prefix_segments])
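# Usage sketch for shave_segments():
#   shave_segments("model.diffusion_model.input_blocks.0.0.weight", 2)
#     -> "input_blocks.0.0.weight"                  (drops the first 2 segments)
#   shave_segments("model.diffusion_model.input_blocks.0.0.weight", -1)
#     -> "model.diffusion_model.input_blocks.0.0"   (drops the last segment)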
|
||||
|
||||
|
||||
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item.replace("in_layers.0", "norm1")
|
||||
new_item = new_item.replace("in_layers.2", "conv1")
|
||||
|
||||
new_item = new_item.replace("out_layers.0", "norm2")
|
||||
new_item = new_item.replace("out_layers.3", "conv2")
|
||||
|
||||
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
||||
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
||||
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
||||
|
||||
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
||||
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
||||
|
||||
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def assign_to_checkpoint(
|
||||
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
||||
):
|
||||
"""
|
||||
This does the final conversion step: take locally converted weights and apply a global renaming
|
||||
to them. It splits attention layers, and takes into account additional replacements
|
||||
that may arise.
|
||||
|
||||
Assigns the weights to the new checkpoint.
|
||||
"""
|
||||
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
# Splits the attention layers into three variables.
|
||||
if attention_paths_to_split is not None:
|
||||
for path, path_map in attention_paths_to_split.items():
|
||||
old_tensor = old_checkpoint[path]
|
||||
channels = old_tensor.shape[0] // 3
|
||||
|
||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||
|
||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||
|
||||
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path["new"]
|
||||
|
||||
# These have already been assigned
|
||||
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
||||
continue
|
||||
|
||||
# Global renaming happens here
|
||||
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
||||
|
||||
if additional_replacements is not None:
|
||||
for replacement in additional_replacements:
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
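# Minimal sketch of how assign_to_checkpoint() is driven further down in this
# file; the layer index and tensor below are hypothetical placeholders.
import torch

old_sd = {"input_blocks.1.0.in_layers.0.weight": torch.zeros(320)}
new_sd = {}
paths = renew_resnet_paths(list(old_sd.keys()))
# -> [{"old": "input_blocks.1.0.in_layers.0.weight",
#      "new": "input_blocks.1.0.norm1.weight"}]
meta_path = {"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}
assign_to_checkpoint(paths, new_sd, old_sd, additional_replacements=[meta_path])
assert "down_blocks.0.resnets.0.norm1.weight" in new_sd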
|
||||
|
||||
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def create_unet_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||
down_block_types.append(block_type)
|
||||
if i != len(block_out_channels) - 1:
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
|
||||
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
||||
use_linear_projection = (
|
||||
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
||||
)
|
||||
if use_linear_projection:
|
||||
# stable diffusion 2-base-512 and 2-768
|
||||
if head_dim is None:
|
||||
head_dim = [5, 10, 20, 20]
|
||||
|
||||
config = dict(
|
||||
sample_size=image_size // vae_scale_factor,
|
||||
in_channels=unet_params.in_channels,
|
||||
out_channels=unet_params.out_channels,
|
||||
down_block_types=tuple(down_block_types),
|
||||
up_block_types=tuple(up_block_types),
|
||||
block_out_channels=tuple(block_out_channels),
|
||||
layers_per_block=unet_params.num_res_blocks,
|
||||
cross_attention_dim=unet_params.context_dim,
|
||||
attention_head_dim=head_dim,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
return config
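# Worked example for a stock SD 1.x checkpoint (model_channels=320,
# channel_mult=[1,2,4,4], attention_resolutions=[4,2,1], vae ch_mult=[1,2,4,4]):
#   block_out_channels -> (320, 640, 1280, 1280)
#   down_block_types   -> 3x "CrossAttnDownBlock2D", then "DownBlock2D"
#     (resolution runs 1, 2, 4, 8; only 1, 2 and 4 are in attention_resolutions)
#   up_block_types     -> "UpBlock2D", then 3x "CrossAttnUpBlock2D"
#   vae_scale_factor   -> 2 ** (4 - 1) = 8, so sample_size = 512 // 8 = 64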
|
||||
|
||||
|
||||
def create_vae_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
_ = original_config.model.params.first_stage_config.params.embed_dim
|
||||
|
||||
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
|
||||
config = dict(
|
||||
sample_size=image_size,
|
||||
in_channels=vae_params.in_channels,
|
||||
out_channels=vae_params.out_ch,
|
||||
down_block_types=tuple(down_block_types),
|
||||
up_block_types=tuple(up_block_types),
|
||||
block_out_channels=tuple(block_out_channels),
|
||||
latent_channels=vae_params.z_channels,
|
||||
layers_per_block=vae_params.num_res_blocks,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def create_diffusers_schedular(original_config):
|
||||
schedular = DDIMScheduler(
|
||||
num_train_timesteps=original_config.model.params.timesteps,
|
||||
beta_start=original_config.model.params.linear_start,
|
||||
beta_end=original_config.model.params.linear_end,
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
return schedular
|
||||
|
||||
|
||||
def create_ldm_bert_config(original_config):
|
||||
bert_params = original_config.model.params.cond_stage_config.params
|
||||
config = LDMBertConfig(
|
||||
d_model=bert_params.n_embed,
|
||||
encoder_layers=bert_params.n_layer,
|
||||
encoder_ffn_dim=bert_params.n_embed * 4,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
||||
"""
|
||||
Takes a state dict and a config, and returns a converted checkpoint.
|
||||
"""
|
||||
|
||||
# extract state_dict for UNet
|
||||
unet_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
unet_key = "model.diffusion_model."
|
||||
# at least 100 parameters have to start with `model_ema` for the checkpoint to be treated as EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
if extract_ema:
|
||||
print(
|
||||
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
||||
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
||||
)
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||
else:
|
||||
print(
|
||||
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
||||
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
if key.startswith(unet_key):
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
||||
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
||||
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
||||
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
||||
|
||||
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
||||
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
||||
|
||||
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
||||
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
||||
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
||||
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
||||
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_output_blocks)
|
||||
}
|
||||
|
||||
for i in range(1, num_input_blocks):
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
||||
|
||||
resnets = [
|
||||
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
||||
]
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
|
||||
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.bias"
|
||||
)
|
||||
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
resnet_0 = middle_blocks[0]
|
||||
attentions = middle_blocks[1]
|
||||
resnet_1 = middle_blocks[2]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnet_0)
|
||||
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
||||
|
||||
resnet_1_paths = renew_resnet_paths(resnet_1)
|
||||
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
||||
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(
|
||||
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
for i in range(num_output_blocks):
|
||||
block_id = i // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
||||
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
||||
output_block_list = {}
|
||||
|
||||
for layer in output_block_layers:
|
||||
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
||||
if layer_id in output_block_list:
|
||||
output_block_list[layer_id].append(layer_name)
|
||||
else:
|
||||
output_block_list[layer_id] = [layer_name]
|
||||
|
||||
if len(output_block_list) > 1:
|
||||
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnets)
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.bias"
|
||||
]
|
||||
|
||||
# Clear attentions as they have been attributed above.
|
||||
if len(attentions) == 2:
|
||||
attentions = []
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {
|
||||
"old": f"output_blocks.{i}.1",
|
||||
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
else:
|
||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||
for path in resnet_0_paths:
|
||||
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
||||
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
||||
|
||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||
|
||||
return new_checkpoint
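# Note on the index arithmetic above: with layers_per_block = 2 (the SD 1.x
# value), input block i maps to
#   block_id          = (i - 1) // 3
#   layer_in_block_id = (i - 1) % 3
# e.g. input_blocks.4 lands in down_blocks.1 at layer 0, and in the SD 1.x
# layout every third input block (i = 3, 6, 9) carries the downsampling "op"
# convolution handled above.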
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
||||
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
||||
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
||||
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
||||
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
||||
|
||||
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
||||
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
||||
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
||||
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
||||
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
||||
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
||||
|
||||
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
||||
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
||||
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
||||
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
||||
|
||||
# Retrieves the keys for the encoder down blocks only
|
||||
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
||||
down_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the decoder up blocks only
|
||||
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
||||
up_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||
f"encoder.down.{i}.downsample.conv.weight"
|
||||
)
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
||||
f"encoder.down.{i}.downsample.conv.bias"
|
||||
)
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
||||
paths = renew_vae_attention_paths(mid_attentions)
|
||||
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
conv_attn_to_linear(new_checkpoint)
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.bias"
|
||||
]
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
||||
paths = renew_vae_attention_paths(mid_attentions)
|
||||
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
conv_attn_to_linear(new_checkpoint)
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
||||
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
||||
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
||||
|
||||
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
||||
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
||||
|
||||
def _copy_linear(hf_linear, pt_linear):
|
||||
hf_linear.weight = pt_linear.weight
|
||||
hf_linear.bias = pt_linear.bias
|
||||
|
||||
def _copy_layer(hf_layer, pt_layer):
|
||||
# copy layer norms
|
||||
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
||||
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
||||
|
||||
# copy attn
|
||||
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
||||
|
||||
# copy MLP
|
||||
pt_mlp = pt_layer[1][1]
|
||||
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
||||
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
||||
|
||||
def _copy_layers(hf_layers, pt_layers):
|
||||
for i, hf_layer in enumerate(hf_layers):
|
||||
if i != 0:
|
||||
i += i
|
||||
pt_layer = pt_layers[i : i + 2]
|
||||
_copy_layer(hf_layer, pt_layer)
|
||||
|
||||
hf_model = LDMBertModel(config).eval()
|
||||
|
||||
# copy embeds
|
||||
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
||||
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
||||
|
||||
# copy layer norm
|
||||
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
||||
|
||||
# copy hidden layers
|
||||
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
||||
|
||||
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
|
||||
|
||||
|
||||
textenc_conversion_lst = [
|
||||
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
||||
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
||||
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
||||
]
|
||||
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
||||
|
||||
textenc_transformer_conversion_lst = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("resblocks.", "text_model.encoder.layers."),
|
||||
("ln_1", "layer_norm1"),
|
||||
("ln_2", "layer_norm2"),
|
||||
(".c_fc.", ".fc1."),
|
||||
(".c_proj.", ".fc2."),
|
||||
(".attn", ".self_attn"),
|
||||
("ln_final.", "transformer.text_model.final_layer_norm."),
|
||||
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
||||
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
||||
]
|
||||
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
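# Usage sketch: the compiled pattern maps open-CLIP transformer keys onto the
# HF CLIPTextModel layout, e.g.
#   textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))],
#                       "resblocks.0.ln_1.weight")
#   -> "text_model.encoder.layers.0.layer_norm1.weight"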
|
||||
|
||||
|
||||
def convert_paint_by_example_checkpoint(checkpoint):
|
||||
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
||||
model = PaintByExampleImageEncoder(config)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
# load clip vision
|
||||
model.model.load_state_dict(text_model_dict)
|
||||
|
||||
# load mapper
|
||||
keys_mapper = {
|
||||
k[len("cond_stage_model.mapper.res") :]: v
|
||||
for k, v in checkpoint.items()
|
||||
if k.startswith("cond_stage_model.mapper")
|
||||
}
|
||||
|
||||
MAPPING = {
|
||||
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
||||
"attn.c_proj": ["attn1.to_out.0"],
|
||||
"ln_1": ["norm1"],
|
||||
"ln_2": ["norm3"],
|
||||
"mlp.c_fc": ["ff.net.0.proj"],
|
||||
"mlp.c_proj": ["ff.net.2"],
|
||||
}
|
||||
|
||||
mapped_weights = {}
|
||||
for key, value in keys_mapper.items():
|
||||
prefix = key[: len("blocks.i")]
|
||||
suffix = key.split(prefix)[-1].split(".")[-1]
|
||||
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
||||
mapped_names = MAPPING[name]
|
||||
|
||||
num_splits = len(mapped_names)
|
||||
for i, mapped_name in enumerate(mapped_names):
|
||||
new_name = ".".join([prefix, mapped_name, suffix])
|
||||
shape = value.shape[0] // num_splits
|
||||
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
||||
|
||||
model.mapper.load_state_dict(mapped_weights)
|
||||
|
||||
# load final layer norm
|
||||
model.final_layer_norm.load_state_dict(
|
||||
{
|
||||
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
||||
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
# load final proj
|
||||
model.proj_out.load_state_dict(
|
||||
{
|
||||
"bias": checkpoint["proj_out.bias"],
|
||||
"weight": checkpoint["proj_out.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
# load uncond vector
|
||||
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
||||
return model
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
||||
|
||||
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
||||
|
||||
for key in keys:
|
||||
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
||||
continue
|
||||
if key in textenc_conversion_map:
|
||||
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
||||
if key.startswith("cond_stage_model.model.transformer."):
|
||||
new_key = key[len("cond_stage_model.model.transformer.") :]
|
||||
if new_key.endswith(".in_proj_weight"):
|
||||
new_key = new_key[: -len(".in_proj_weight")]
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
||||
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
||||
elif new_key.endswith(".in_proj_bias"):
|
||||
new_key = new_key[: -len(".in_proj_bias")]
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
||||
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
||||
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
||||
else:
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
|
||||
text_model_dict[new_key] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
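# Sketch of the in_proj split performed above: open-CLIP stores the fused
# attention projection as a single [3 * d_model, d_model] matrix that is cut
# into equal q/k/v slices (dummy tensor; d_model=1024 as in the SD 2.x text
# encoder).
import torch

d_model = 1024
in_proj_weight = torch.randn(3 * d_model, d_model)
q_w = in_proj_weight[:d_model, :]
k_w = in_proj_weight[d_model:d_model * 2, :]
v_w = in_proj_weight[d_model * 2:, :]
assert q_w.shape == k_w.shape == v_w.shape == (d_model, d_model)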
|
||||
|
||||
def convert_ckpt_to_diffuser(checkpoint_path:str,
|
||||
dump_path:str,
|
||||
original_config_file:str=None,
|
||||
num_in_channels:int=None,
|
||||
scheduler_type:str='pndm',
|
||||
pipeline_type:str=None,
|
||||
image_size:int=None,
|
||||
prediction_type:str=None,
|
||||
extract_ema:bool=False,
|
||||
upcast_attn:bool=False,
|
||||
):
|
||||
|
||||
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
|
||||
|
||||
# Sometimes models don't have the global_step item
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
if 'state_dict' in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
upcast_attention = False
|
||||
if original_config_file is None:
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v2-inference-v.yaml')
|
||||
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
else:
|
||||
original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v1-inference.yaml')
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
|
||||
if (
|
||||
"parameterization" in original_config["model"]["params"]
|
||||
and original_config["model"]["params"]["parameterization"] == "v"
|
||||
):
|
||||
if prediction_type is None:
|
||||
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
|
||||
# as it relies on a brittle global step parameter here
|
||||
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
|
||||
if image_size is None:
|
||||
# NOTE: For stable diffusion 2 base one has to pass `image_size=512`
|
||||
# as it relies on a brittle global step parameter here
|
||||
image_size = 512 if global_step == 875000 else 768
|
||||
else:
|
||||
if prediction_type is None:
|
||||
prediction_type = "epsilon"
|
||||
if image_size is None:
|
||||
image_size = 512
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
# make sure scheduler works correctly with DDIM
|
||||
scheduler.register_to_config(clip_sample=False)
|
||||
|
||||
if scheduler_type == "pndm":
|
||||
config = dict(scheduler.config)
|
||||
config["skip_prk_steps"] = True
|
||||
scheduler = PNDMScheduler.from_config(config)
|
||||
elif scheduler_type == "lms":
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "heun":
|
||||
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler":
|
||||
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler-ancestral":
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
|
||||
)
|
||||
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
# Convert the text model.
|
||||
model_type = pipeline_type
|
||||
if model_type is None:
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
elif model_type == "PaintByExample":
|
||||
vision_model = convert_paint_by_example_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
pipe = PaintByExamplePipeline(
|
||||
vae=vae,
|
||||
image_encoder=vision_model,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
|
||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=1,
|
||||
)
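# Usage sketch (paths are hypothetical): converting a legacy .ckpt or
# .safetensors checkpoint into a diffusers folder with this module looks
# roughly like
#
#   from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
#   convert_ckpt_to_diffuser(
#       checkpoint_path='path/to/model.ckpt',
#       dump_path='path/to/diffusers_model',
#       extract_ema=True,   # prefer EMA weights when both sets are present
#   )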
|
@@ -12,7 +12,7 @@ from urllib import request, error as ul_error
|
||||
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
class Concepts(object):
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
'''
|
||||
Initialize the Concepts object. May optionally pass a root directory.
|
||||
@@ -29,7 +29,7 @@ class Concepts(object):
|
||||
|
||||
def list_concepts(self)->list:
|
||||
'''
|
||||
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
||||
Also adds local concepts in invokeai/embeddings folder.
|
||||
'''
|
||||
local_concepts_now = self.get_local_concepts(os.path.join(self.root, 'embeddings'))
|
||||
@ -71,11 +71,11 @@ class Concepts(object):
|
||||
if concept_name in self.triggers:
|
||||
return self.triggers[concept_name]
|
||||
elif self.concept_is_local(concept_name):
|
||||
trigger = f'<{concept_name}>'
|
||||
trigger = f'<{concept_name}>'
|
||||
self.triggers[concept_name] = trigger
|
||||
self.concept_names[trigger] = concept_name
|
||||
return trigger
|
||||
|
||||
|
||||
file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
|
||||
if not file:
|
||||
return None
|
||||
@ -135,7 +135,7 @@ class Concepts(object):
|
||||
def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
|
||||
if not (self.concept_is_downloaded(concept_name) or self.concept_is_local(concept_name) or local_only):
|
||||
self.download_concept(concept_name)
|
||||
|
||||
|
||||
# get local path in invokeai/embeddings if local concept
|
||||
if self.concept_is_local(concept_name):
|
||||
concept_path = self._concept_local_path(concept_name)
|
||||
@ -144,7 +144,7 @@ class Concepts(object):
|
||||
concept_path = self._concept_path(concept_name)
|
||||
path = os.path.join(concept_path, file_name)
|
||||
return path if os.path.exists(path) else None
|
||||
|
||||
|
||||
def concept_is_local(self, concept_name)->bool:
|
||||
return concept_name in self.local_concepts
|
||||
|
||||
@ -197,7 +197,7 @@ class Concepts(object):
|
||||
return os.path.join(self.root,'models','sd-concepts-library',concept_name)
|
||||
|
||||
def _concept_local_path(self, concept_name:str)->str:
|
||||
filename = self.local_concepts[concept_name]
|
||||
filename = self.local_concepts[concept_name]
|
||||
return os.path.join(self.root,'embeddings',filename)
|
||||
|
||||
def get_local_concepts(self, loc_dir:str):
|
||||
|
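A brief usage sketch of the renamed class, using only methods visible in the hunks above; the instance name and the index are illustrative:

library = HuggingFaceConceptsLibrary()            # root defaults to the InvokeAI runtime directory
concepts = library.list_concepts()                # remote sd-concepts-library names plus local embeddings
path = library.get_concept_file(concepts[0])      # downloads on demand unless the concept is local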
@@ -16,9 +16,15 @@ from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter


def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):

    # lazy-load any deferred textual inversions.
    # this might take a couple of seconds the first time a textual inversion is used.
    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

    prompt, negative_prompt = get_prompt_structure(prompt_string,
                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
    conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
@@ -216,7 +222,7 @@ def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
                                                            log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
        embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
            (embeddings_to_blend, this_embedding))
    conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
    conditioning = WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
                                                                      blend.weights,
                                                                      normalize=blend.normalize_weights)
    return conditioning
@@ -238,7 +244,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm

def _get_tokens_length(model, fragments: list[Fragment]):
    fragment_texts = [x.text for x in fragments]
    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
    tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
    return sum([len(x) for x in tokens])
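A hedged sketch of the entry point touched above; model stands in for a loaded InvokeAI model, and the square brackets use InvokeAI's negative-prompt syntax:

# uc/c are the unconditioned and conditioned embeddings; ec carries cross-attention extras.
uc, c, ec = get_uc_and_c_and_ec("a watercolor fox, highly detailed [blurry, low quality]", model=model)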
@@ -1,9 +1,12 @@
import torch
from torch import autocast
from contextlib import nullcontext
from ldm.invoke.globals import Globals

def choose_torch_device() -> str:
    '''Convenience routine for guessing which GPU device to run model on'''
    if Globals.always_use_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return 'cuda'
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
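A minimal sketch of the helper after this change; with Globals.always_use_cpu set, the CUDA and MPS checks are skipped entirely (model is a hypothetical torch module):

from ldm.invoke.devices import choose_torch_device

device = choose_torch_device()   # 'cuda', 'mps', or 'cpu', honoring Globals.always_use_cpu
model = model.to(device)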
@@ -2,26 +2,37 @@
Base class for ldm.invoke.generator.*
including img2img, txt2img, and inpaint
'''
import torch
import numpy as np
import random
from __future__ import annotations

import os
import os.path as osp
import random
import traceback
from tqdm import tqdm, trange

import cv2
import numpy as np
import torch
from PIL import Image, ImageFilter, ImageChops
import cv2 as cv
from einops import rearrange, repeat
from diffusers import DiffusionPipeline
from einops import rearrange
from pytorch_lightning import seed_everything
from tqdm import trange

from ldm.invoke.devices import choose_autocast
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.models.diffusion.ddpm import DiffusionWrapper
from ldm.util import rand_perlin_2d

downsampling = 8
CAUTION_IMG = 'assets/caution.png'

class Generator():
    def __init__(self, model, precision):
class Generator:
    downsampling_factor: int
    latent_channels: int
    precision: str
    model: DiffusionWrapper | DiffusionPipeline

    def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
        self.model = model
        self.precision = precision
        self.seed = None
@@ -52,7 +63,6 @@ class Generator():
    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                 safety_checker:dict=None,
                 attention_maps_callback = None,
                 **kwargs):
        scope = choose_autocast(self.precision)
        self.safety_checker = safety_checker
@@ -165,7 +175,7 @@ class Generator():
        # Blur the mask out (into init image) by specified amount
        if mask_blur_radius > 0:
            nm = np.asarray(pil_init_mask, dtype=np.uint8)
            nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
            nmd = cv2.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
            pmd = Image.fromarray(nmd, mode='L')
            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
        else:
@@ -177,8 +187,6 @@ class Generator():
        matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
        return matched_result


    def sample_to_lowres_estimated_image(self,samples):
        # originally adapted from code by @erucipe and @keturn here:
        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
632 ldm/invoke/generator/diffusers_pipeline.py (new file)
@@ -0,0 +1,632 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import secrets
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
from typing_extensions import ParamSpec
|
||||
else:
|
||||
from typing import ParamSpec
|
||||
|
||||
import PIL.Image
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import attention
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...models.diffusion import cross_attention_control
|
||||
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
|
||||
|
||||
# monkeypatch diffusers CrossAttention 🙈
|
||||
# this is to make prompt2prompt and (future) attention maps work
|
||||
attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
|
||||
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
run_id: str
|
||||
step: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
|
||||
# copied from configs/stable-diffusion/v1-inference.yaml
|
||||
_default_personalization_config_params = dict(
|
||||
placeholder_strings=["*"],
|
||||
initializer_words=["sculpture"],
|
||||
per_image_tokens=False,
|
||||
num_vectors_per_token=1,
|
||||
progressive_words=False
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskLatents:
|
||||
"""Add the channels required for inpainting model input.
|
||||
|
||||
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
|
||||
and the latent encoding of the base image.
|
||||
|
||||
This class assumes the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
mask: torch.Tensor
|
||||
initial_image_latents: torch.Tensor
|
||||
|
||||
def __call__(self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
|
||||
model_input = self.add_mask_channels(latents)
|
||||
return self.forward(model_input, t, text_embeddings)
|
||||
|
||||
def add_mask_channels(self, latents):
|
||||
batch_size = latents.size(0)
|
||||
# duplicate mask and latents for each batch
|
||||
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
# add mask and image as additional channels
|
||||
model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
|
||||
return model_input
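The packed input is what is_inpainting_model later recognizes by its 9-channel conv_in. A quick shape check of the packing above, with made-up tensor sizes:

import einops, torch
latents = torch.zeros(2, 4, 64, 64)          # ordinary 4-channel latents, batch of 2
mask = torch.zeros(1, 1, 64, 64)             # one-channel inpainting mask
image_latents = torch.zeros(1, 4, 64, 64)    # latents of the base image
mask = einops.repeat(mask, 'b c h w -> (repeat b) c h w', repeat=2)
image_latents = einops.repeat(image_latents, 'b c h w -> (repeat b) c h w', repeat=2)
packed, _ = einops.pack([latents, mask, image_latents], 'b * h w')
assert packed.shape == (2, 9, 64, 64)        # 4 + 1 + 4 channels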
|
||||
|
||||
|
||||
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
|
||||
return (
|
||||
isinstance(b, torch.Tensor)
|
||||
and (a.size() == b.size())
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
mask: torch.FloatTensor
|
||||
mask_latents: torch.FloatTensor
|
||||
scheduler: SchedulerMixin
|
||||
noise: torch.Tensor
|
||||
_debug: Optional[Callable] = None
|
||||
|
||||
def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
|
||||
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||
|
||||
# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
|
||||
# It's reasonable to assume the first thing is prev_sample, but then does it have other things
|
||||
# like pred_original_sample? Should we apply the mask to them too?
|
||||
# But what if there's just some other random field?
|
||||
prev_sample = step_output[0]
|
||||
# Mask anything that has the same shape as prev_sample, return others as-is.
|
||||
return output_class(
|
||||
{k: (self.apply_mask(v, self._t_for_field(k, t))
|
||||
if are_like_tensors(prev_sample, v) else v)
|
||||
for k, v in step_output.items()}
|
||||
)
|
||||
|
||||
def _t_for_field(self, field_name:str, t):
|
||||
if field_name == "pred_original_sample":
|
||||
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
|
||||
return t
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
if t.dim() == 0:
|
||||
# some schedulers expect t to be one-dimensional.
|
||||
# TODO: file diffusers bug about inconsistency?
|
||||
t = einops.repeat(t, '-> batch', batch=batch_size)
|
||||
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
||||
# get very confused about what is happening from step to step when we do that.
|
||||
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
|
||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||
if self._debug:
|
||||
self._debug(masked_input, f"t={t} lerped")
|
||||
return masked_input
|
||||
|
||||
|
||||
def trim_to_multiple_of(*args, multiple_of=8):
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
|
||||
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
|
||||
"""
|
||||
|
||||
:param image: input image
|
||||
:param normalize: scale the range to [-1, 1] instead of [0, 1]
|
||||
:param multiple_of: resize the input so both dimensions are a multiple of this
|
||||
"""
|
||||
w, h = trim_to_multiple_of(*image.size)
|
||||
transformation = T.Compose([
|
||||
T.Resize((h, w), T.InterpolationMode.LANCZOS),
|
||||
T.ToTensor(),
|
||||
])
|
||||
tensor = transformation(image)
|
||||
if normalize:
|
||||
tensor = tensor * 2.0 - 1.0
|
||||
return tensor
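A small worked example of the two helpers above (sizes are arbitrary): a 513x769 image is snapped down so both sides are multiples of 8.

import PIL.Image
assert trim_to_multiple_of(513, 769) == (512, 768)
img = PIL.Image.new('RGB', (513, 769))                 # PIL size is (width, height)
t = image_resized_to_grid_as_tensor(img)               # default normalize=True -> values in [-1, 1]
assert tuple(t.shape) == (3, 768, 512)                 # (channels, height, width)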
|
||||
|
||||
|
||||
def is_inpainting_model(unet: UNet2DConditionModel):
|
||||
return unet.conv_in.in_channels == 9
|
||||
|
||||
CallbackType = TypeVar('CallbackType')
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
ParamType = ParamSpec('ParamType')
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
"""Convert a generator to a function with a callback and a return value."""
|
||||
|
||||
generator_method: Callable[ParamType, ReturnType]
|
||||
callback_arg_type: Type[CallbackType]
|
||||
|
||||
def __call__(self, *args: ParamType.args,
|
||||
callback:Callable[[CallbackType], Any]=None,
|
||||
**kwargs: ParamType.kwargs) -> ReturnType:
|
||||
result = None
|
||||
for result in self.generator_method(*args, **kwargs):
|
||||
if callback is not None and isinstance(result, self.callback_arg_type):
|
||||
callback(result)
|
||||
if result is None:
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
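A toy illustration of the wrapper above, outside the pipeline: intermediate yields of the matching type go to the callback, and the last yielded value becomes the return value.

def countdown(n: int):
    for i in range(n, 0, -1):
        yield i            # intermediate states, reported via the callback
    yield "done"           # final value, returned by the wrapper

to_callback = GeneratorToCallbackinator(countdown, int)
result = to_callback(3, callback=print)   # prints 3, 2, 1
assert result == "done"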
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: torch.Tensor
|
||||
text_embeddings: torch.Tensor
|
||||
guidance_scale: float
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""Additional arguments to pass to scheduler.step."""
|
||||
threshold: Optional[ThresholdSettings] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
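In practice this is how the DDIM-only eta argument stays out of schedulers whose step() does not accept it; a hedged sketch with placeholder embeddings:

import torch
from diffusers import DDIMScheduler, PNDMScheduler

uc = c = torch.zeros(1, 77, 768)             # placeholder embeddings
cd = ConditioningData(unconditioned_embeddings=uc, text_embeddings=c, guidance_scale=7.5)
with_eta = cd.add_scheduler_args_if_applicable(DDIMScheduler(), eta=0.0)
without = cd.add_scheduler_args_if_applicable(PNDMScheduler(), eta=0.0)
assert 'eta' in with_eta.scheduler_args and 'eta' not in without.scheduler_args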
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
|
||||
Hopefully future versions of diffusers provide access to more of these functions so that we don't
|
||||
need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
ID_LENGTH = 8
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||
requires_safety_checker: bool = False,
|
||||
precision: str = 'float32',
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
|
||||
safety_checker, feature_extractor, requires_safety_checker)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
use_full_precision = (precision == 'float32' or precision == 'autocast')
|
||||
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision)
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
textual_inversion_manager=self.textual_inversion_manager
|
||||
)
|
||||
|
||||
if is_xformers_available():
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
|
||||
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
callback: Callable[[PipelineIntermediateState], None]=None,
|
||||
run_id=None) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
:param conditioning_data:
|
||||
:param latents: Pre-generated un-noised latents, to be used as inputs for
|
||||
image generation. Can be used to tweak the same generation with different prompts.
|
||||
:param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
|
||||
image at the expense of slower inference.
|
||||
:param noise: Noise to add to the latents, sampled from a Gaussian distribution.
|
||||
:param callback:
|
||||
:param run_id:
|
||||
"""
|
||||
result_latents, result_attention_map_saver = self.latents_from_embeddings(
|
||||
latents, num_inference_steps,
|
||||
conditioning_data,
|
||||
noise=noise,
|
||||
run_id=run_id,
|
||||
callback=callback)
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
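A hedged sketch of driving this entry point directly; pipe, uc and c are assumed to already exist, and the shapes and step count are illustrative. For txt2img the latents start as zeros and all randomness comes in through noise, exactly as the Txt2Img generator does further down in this commit.

import torch
noise = torch.randn(1, 4, 64, 64, device=pipe.unet.device, dtype=pipe.unet.dtype)  # 4 latent channels -> 512x512 output
cd = ConditioningData(unconditioned_embeddings=uc, text_embeddings=c, guidance_scale=7.5)
output = pipe.image_from_embeddings(
    latents=torch.zeros_like(noise), num_inference_steps=30,
    conditioning_data=cd, noise=noise,
    callback=lambda state: print(f"step {state.step}"))
image = pipe.numpy_to_pil(output.images)[0]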
|
||||
|
||||
def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
timesteps=None,
|
||||
additional_guidance: List[Callable] = None, run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if timesteps is None:
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
|
||||
result: PipelineIntermediateState = infer_latents_from_embeddings(
|
||||
latents, timesteps, conditioning_data,
|
||||
noise=noise,
|
||||
additional_guidance=additional_guidance,
|
||||
run_id=run_id,
|
||||
callback=callback)
|
||||
return result.latents, result.attention_map_saver
|
||||
|
||||
def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None):
|
||||
if run_id is None:
|
||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps))
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||
latents=latents)
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batched_t = torch.full((batch_size,), timesteps[0],
|
||||
dtype=timesteps.dtype, device=self.unet.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
self.invokeai_diffuser.remove_attention_map_saving()
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(batched_t, latents, conditioning_data,
|
||||
i, additional_guidance=additional_guidance)
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||
|
||||
if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
attention_map_token_ids = range(1, eos_token_index)
|
||||
attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
|
||||
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
|
||||
|
||||
self.invokeai_diffuser.remove_attention_map_saving()
|
||||
return latents, attention_map_saver
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(self, t: torch.Tensor, latents: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
step_index:int | None = None, additional_guidance: List[Callable] = None):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
||||
latent_model_input, t,
|
||||
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
|
||||
conditioning_data.guidance_scale,
|
||||
step_index=step_index,
|
||||
threshold=conditioning_data.threshold
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents,
|
||||
**conditioning_data.scheduler_args)
|
||||
|
||||
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
|
||||
# But the way things are now, scheduler runs _after_ that, so there was
|
||||
# no way to use it to apply an operation that happens after the last scheduler.step.
|
||||
for guidance in additional_guidance:
|
||||
step_output = guidance(step_output, timestep, conditioning_data)
|
||||
|
||||
return step_output
|
||||
|
||||
def _unet_forward(self, latents, t, text_embeddings):
|
||||
"""predict the noise residual"""
|
||||
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||
# Pad out normal non-inpainting inputs for an inpainting model.
|
||||
# FIXME: There are too many layers of functions and we have too many different ways of
|
||||
# overriding things! This should get handled in a way more consistent with the other
|
||||
# use of AddsMaskLatents.
|
||||
latents = AddsMaskLatents(
|
||||
self._unet_forward,
|
||||
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
|
||||
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
|
||||
).add_mask_channels(latents)
|
||||
|
||||
return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
def img2img_from_embeddings(self,
|
||||
init_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
strength: float,
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*, callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
run_id=None,
|
||||
noise_func=None
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||
|
||||
if init_image.dim() == 3:
|
||||
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
|
||||
|
||||
# 6. Prepare latent variables
|
||||
device = self.unet.device
|
||||
latents_dtype = self.unet.dtype
|
||||
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
|
||||
noise = noise_func(initial_latents)
|
||||
|
||||
return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
|
||||
conditioning_data,
|
||||
strength,
|
||||
noise, run_id, callback)
|
||||
|
||||
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps,
|
||||
conditioning_data: ConditioningData,
|
||||
strength,
|
||||
noise: torch.Tensor, run_id=None, callback=None
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
device = self.unet.device
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
initial_latents, num_inference_steps, conditioning_data,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
run_id=run_id,
|
||||
callback=callback)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def inpaint_from_embeddings(
|
||||
self,
|
||||
init_image: torch.FloatTensor,
|
||||
mask: torch.FloatTensor,
|
||||
strength: float,
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*, callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
run_id=None,
|
||||
noise_func=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
device = self.unet.device
|
||||
latents_dtype = self.unet.dtype
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||
|
||||
init_image = init_image.to(device=device, dtype=latents_dtype)
|
||||
|
||||
if init_image.dim() == 3:
|
||||
init_image = init_image.unsqueeze(0)
|
||||
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||
|
||||
assert img2img_pipeline.scheduler is self.scheduler
|
||||
|
||||
# 6. Prepare latent variables
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
# because we have our own noise function
|
||||
init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
|
||||
noise = noise_func(init_image_latents)
|
||||
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
|
||||
.to(device=device, dtype=latents_dtype)
|
||||
|
||||
guidance: List[Callable] = []
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
|
||||
self.invokeai_diffuser.model_forward_callback = \
|
||||
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
|
||||
else:
|
||||
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
|
||||
|
||||
try:
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
init_image_latents, num_inference_steps,
|
||||
conditioning_data, noise=noise, timesteps=timesteps,
|
||||
additional_guidance=guidance,
|
||||
run_id=run_id, callback=callback)
|
||||
finally:
|
||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def non_noised_latents_from_image(self, init_image, *, device, dtype):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
with torch.inference_mode():
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
init_latents = 0.18215 * init_latents
|
||||
return init_latents
|
||||
|
||||
def check_for_safety(self, output, dtype):
|
||||
with torch.inference_mode():
|
||||
screened_images, has_nsfw_concept = self.run_safety_checker(
|
||||
output.images, device=self._execution_device, dtype=dtype)
|
||||
screened_attention_map_saver = None
|
||||
if has_nsfw_concept is None or not has_nsfw_concept:
|
||||
screened_attention_map_saver = output.attention_map_saver
|
||||
return InvokeAIStableDiffusionPipelineOutput(screened_images,
|
||||
has_nsfw_concept,
|
||||
# block the attention maps if NSFW content is detected
|
||||
attention_map_saver=screened_attention_map_saver)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
|
||||
"""
|
||||
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
|
||||
"""
|
||||
return self.prompt_fragments_to_embeddings_converter.get_embeddings_for_weighted_prompt_fragments(
|
||||
text=c,
|
||||
fragment_weights=fragment_weights,
|
||||
should_return_tokens=return_tokens,
|
||||
device=self.device)
|
||||
|
||||
@property
|
||||
def cond_stage_model(self):
|
||||
warnings.warn("legacy compatibility layer", DeprecationWarning)
|
||||
return self.prompt_fragments_to_embeddings_converter
|
||||
|
||||
@torch.inference_mode()
|
||||
def _tokenize(self, prompt: Union[str, List[str]]):
|
||||
return self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.in_channels
|
||||
|
||||
def debug_latents(self, latents, msg):
|
||||
with torch.inference_mode():
|
||||
from ldm.util import debug_image
|
||||
decoded = self.numpy_to_pil(self.decode_latents(latents))
|
||||
for i, img in enumerate(decoded):
|
||||
debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
|
@ -3,14 +3,14 @@ ldm.invoke.generator.embiggen descends from ldm.invoke.generator
|
||||
and generates with ldm.invoke.generator.img2img
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
from PIL import Image
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -22,7 +22,6 @@ class Embiggen(Generator):
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
step_callback = step_callback,
|
||||
@ -32,14 +31,13 @@ class Embiggen(Generator):
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
with scope(self.model.device.type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, prompt_in=prompt)
|
||||
seed = self.new_seed()
|
||||
for _ in trange(iterations, desc='Generating'):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, prompt_in=prompt)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
@ -353,7 +351,7 @@ class Embiggen(Generator):
|
||||
prompt,
|
||||
iterations = 1,
|
||||
seed = seed,
|
||||
sampler = DDIMSampler(self.model, device=self.model.device),
|
||||
sampler = sampler,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
conditioning = conditioning,
|
||||
@ -493,7 +491,7 @@ class Embiggen(Generator):
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')
|
||||
print('Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
return outputsuperimage
|
||||
|
@ -3,14 +3,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import PIL
|
||||
from torch import Tensor
|
||||
from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from diffusers import logging
|
||||
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -18,80 +16,69 @@ class Img2Img(Generator):
|
||||
self.init_latent = None # by get_noise()
|
||||
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
|
||||
attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = self._image_to_tensor(init_image.convert('RGB'))
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc - 1]).to(self.model.device),
|
||||
noise=x_T
|
||||
# FIXME: use x_T for initial seeded noise
|
||||
# We're not at the moment because the pipeline automatically resizes init_image if
|
||||
# necessary, which the x_T input might not match.
|
||||
logging.set_verbosity_error() # quench safety check warnings
|
||||
pipeline_output = pipeline.img2img_from_embeddings(
|
||||
init_image, strength, steps, conditioning_data,
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
all_timesteps_count = steps
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
init_latent = self.init_latent
|
||||
assert init_latent is not None,'call to get_noise() when init_latent not set'
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(init_latent, device='cpu').to(device)
|
||||
x = torch.randn_like(like, device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn_like(init_latent, device=device)
|
||||
x = torch.randn_like(like, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = init_latent.shape
|
||||
shape = like.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
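A small sketch of the helper above; on MPS the noise is drawn on the CPU first and then moved to the device, per the branch shown. generator is a hypothetical Img2Img instance and the tensor is a stand-in for real latents.

import torch
like = torch.zeros(1, 4, 64, 64, device='mps')   # hypothetical latents on an Apple-silicon GPU
noise = generator.get_noise_like(like)
assert noise.shape == like.shape and noise.device == like.device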
|
||||
|
||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
if len(image.shape) == 2: # 'L' image, as in a mask
|
||||
image = image[None,None]
|
||||
else: # 'RGB' image
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
if normalize:
|
||||
image = 2.0 * image - 1.0
|
||||
return image.to(self.model.device)
|
||||
def get_noise(self,width,height):
|
||||
# copy of the Txt2Img.get_noise
|
||||
device = self.model.device
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
||||
if self.perlin > 0.0:
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
return x
|
||||
|
@ -1,24 +1,22 @@
|
||||
'''
|
||||
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
||||
'''
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
|
||||
import PIL
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from einops import rearrange, repeat
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
|
||||
from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline, \
|
||||
ConditioningData
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.patchmatch import PatchMatch
|
||||
from ldm.util import debug_image
|
||||
from ldm.invoke.patchmatch import PatchMatch
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
|
||||
def infill_methods()->list[str]:
|
||||
methods = list()
|
||||
@ -29,6 +27,9 @@ def infill_methods()->list[str]:
|
||||
|
||||
class Inpaint(Img2Img):
|
||||
def __init__(self, model, precision):
|
||||
self.inpaint_height = 0
|
||||
self.inpaint_width = 0
|
||||
self.enable_image_debugging = False
|
||||
self.init_latent = None
|
||||
self.pil_image = None
|
||||
self.pil_mask = None
|
||||
@ -117,13 +118,13 @@ class Inpaint(Img2Img):
|
||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||
|
||||
# Detect hard edges
|
||||
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
|
||||
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
|
||||
|
||||
# Combine
|
||||
npmask = npgradient + npedge
|
||||
|
||||
# Expand
|
||||
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||
npmask = cv2.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
@ -133,15 +134,8 @@ class Inpaint(Img2Img):
|
||||
return ImageOps.invert(new_mask)
|
||||
|
||||
|
||||
def seam_paint(self,
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,strength,
|
||||
noise,
|
||||
step_callback
|
||||
) -> Image.Image:
|
||||
def seam_paint(self, im: Image.Image, seam_size: int, seam_blur: int, prompt, sampler, steps, cfg_scale, ddim_eta,
|
||||
conditioning, strength, noise, infill_method, step_callback) -> Image.Image:
|
||||
hard_mask = self.pil_image.split()[-1].copy()
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
@ -153,13 +147,14 @@ class Inpaint(Img2Img):
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image = im.copy().convert('RGBA'),
|
||||
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
|
||||
mask_image = mask,
|
||||
strength = strength,
|
||||
mask_blur_radius = 0,
|
||||
seam_size = 0,
|
||||
step_callback = step_callback,
|
||||
inpaint_width = im.width,
|
||||
inpaint_height = im.height
|
||||
inpaint_height = im.height,
|
||||
infill_method = infill_method
|
||||
)
|
||||
|
||||
seam_noise = self.get_noise(im.width, im.height)
|
||||
@ -171,7 +166,10 @@ class Inpaint(Img2Img):
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
conditioning,
|
||||
init_image: PIL.Image.Image | torch.FloatTensor,
|
||||
mask_image: PIL.Image.Image | torch.FloatTensor,
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
@ -184,6 +182,7 @@ class Inpaint(Img2Img):
|
||||
infill_method = None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
@ -193,7 +192,7 @@ class Inpaint(Img2Img):
|
||||
|
||||
self.enable_image_debugging = enable_image_debugging
|
||||
self.infill_method = infill_method or infill_methods()[0], # The infill method to use
|
||||
|
||||
|
||||
self.inpaint_width = inpaint_width
|
||||
self.inpaint_height = inpaint_height
|
||||
|
||||
@ -218,13 +217,17 @@ class Inpaint(Img2Img):
|
||||
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
||||
|
||||
# Create init tensor
|
||||
init_image = self._image_to_tensor(init_filled.convert('RGB'))
|
||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert('RGB'))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image.copy()
|
||||
debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||
|
||||
mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
|
||||
init_alpha = self.pil_image.getchannel("A")
|
||||
if mask_image.mode != "L":
|
||||
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_image = ImageChops.multiply(mask_image, init_alpha)
|
||||
self.pil_mask = mask_image
|
||||
|
||||
# Resize if requested for inpainting
|
||||
@ -232,95 +235,45 @@ class Inpaint(Img2Img):
|
||||
mask_image = mask_image.resize((inpaint_width, inpaint_height))
|
||||
|
||||
debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||
mask_image = mask_image.resize(
|
||||
(
|
||||
mask_image.width // downsampling,
|
||||
mask_image.height // downsampling
|
||||
),
|
||||
resample=Image.Resampling.NEAREST
|
||||
)
|
||||
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
||||
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
else:
|
||||
mask: torch.FloatTensor = mask_image
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
# klms samplers not supported yet, so ignore previous sampler
|
||||
if isinstance(sampler,KSampler):
|
||||
print(
|
||||
f">> Using recommended DDIM sampler for inpainting."
|
||||
)
|
||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
|
||||
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
conditioning_data = (ConditioningData(uc, c, cfg_scale)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
print(f">> target t_enc is {t_enc} steps")
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc - 1]).to(self.model.device),
|
||||
noise=x_T
|
||||
pipeline_output = pipeline.inpaint_from_embeddings(
|
||||
init_image=init_image,
|
||||
mask=1 - mask, # expects white means "paint here."
|
||||
strength=strength,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# to replace masked area with latent noise, weighted by inpaint_replace strength
|
||||
if inpaint_replace > 0.0:
|
||||
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
|
||||
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
|
||||
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
|
||||
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
|
||||
z_enc = z_enc * mask_image + masked_region
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
mask = mask_image,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
|
||||
result = self.sample_to_image(samples)
|
||||
result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
if seam_size > 0:
|
||||
old_image = self.pil_image or init_image
|
||||
old_mask = self.pil_mask or mask_image
|
||||
|
||||
result = self.seam_paint(
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
seam_strength,
|
||||
x_T,
|
||||
step_callback)
|
||||
result = self.seam_paint(result, seam_size, seam_blur, prompt, sampler, seam_steps, cfg_scale, ddim_eta,
|
||||
conditioning, seam_strength, x_T, infill_method, step_callback)
|
||||
|
||||
# Restore original settings
|
||||
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
@ -343,6 +296,10 @@ class Inpaint(Img2Img):
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
return self.postprocess_size_and_mask(gen_result)
|
||||
|
||||
|
||||
def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image:
|
||||
debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
|
||||
|
||||
# Resize if necessary
|
||||
@ -352,7 +309,7 @@ class Inpaint(Img2Img):
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
corrected_result = self.repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)
|
||||
|
||||
return corrected_result
|
||||
|
@ -1,14 +1,14 @@
|
||||
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
from einops import repeat
|
||||
from PIL import Image, ImageOps, ImageChops
|
||||
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.generator.txt2img import Txt2Img
|
||||
|
||||
|
||||
class Omnibus(Img2Img,Txt2Img):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
@@ -40,6 +40,8 @@ class Omnibus(Img2Img,Txt2Img):
|
||||
self.perlin = perlin
|
||||
num_samples = 1
|
||||
|
||||
print('DEBUG: IN OMNIBUS')
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
@@ -58,8 +60,6 @@ class Omnibus(Img2Img,Txt2Img):
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
t_enc = steps
|
||||
|
||||
if init_image is not None and mask_image is not None: # inpainting
|
||||
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||
|
||||
|
@@ -1,12 +1,12 @@
|
||||
'''
|
||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
import gc
|
||||
|
||||
from .base import Generator
|
||||
from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||
from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
@@ -24,45 +24,30 @@ class Txt2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
def make_image(x_T) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
noise=x_T,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to('cpu')
|
||||
self.model.cond_stage_model.device = 'cpu'
|
||||
self.model.cond_stage_model.to('cpu')
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
|
||||
@@ -70,15 +55,17 @@ class Txt2Img(Generator):
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
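A condensed sketch of the noise-shape logic above (not the repository's code): one batch of latent-resolution noise over at most four channels, sampled on the CPU first when running on MPS so that seeds match other devices:

import torch

def initial_noise(latent_channels: int, width: int, height: int,
                  downsampling_factor: int, device: torch.device) -> torch.Tensor:
    input_channels = min(latent_channels, 4)   # skip the inpainting mask channels
    shape = [1, input_channels,
             height // downsampling_factor,
             width // downsampling_factor]
    if device.type == 'mps':
        return torch.randn(shape, device='cpu').to(device)
    return torch.randn(shape, device=device)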
|
||||
|
@@ -2,67 +2,55 @@
|
||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
import gc
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.invoke.generator.omnibus import Omnibus
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from PIL import Image
|
||||
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
|
||||
ConditioningData
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,strength,step_callback=None,**kwargs):
|
||||
def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
|
||||
conditioning, width:int, height:int, strength:float,
|
||||
step_callback:Optional[Callable]=None, threshold=0.0, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
scale_dim = min(width, height)
|
||||
scale = 512 / scale_dim
|
||||
|
||||
init_width = math.ceil(scale * width / 64) * 64
|
||||
init_height = math.ceil(scale * height / 64) * 64
|
||||
init_width, init_height = trim_to_multiple_of(scale * width, scale * height)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
init_height // self.downsampling_factor,
|
||||
init_width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
#x = self.get_noise(init_width, init_height)
|
||||
x = x_T
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
noise=x_T,
|
||||
callback=step_callback,
|
||||
# TODO: threshold = threshold,
|
||||
)
|
||||
|
||||
print(
|
||||
@@ -70,88 +58,45 @@ class Txt2Img2Img(Generator):
|
||||
)
|
||||
|
||||
# resizing
|
||||
samples = torch.nn.functional.interpolate(
|
||||
samples,
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
first_pass_latent_output,
|
||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||
mode="bilinear"
|
||||
)
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
ddim_sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
second_pass_noise = self.get_noise_like(resized_latents)
|
||||
|
||||
z_enc = ddim_sampler.stochastic_encode(
|
||||
samples,
|
||||
torch.tensor([t_enc-1]).to(self.model.device),
|
||||
noise=self.get_noise(width,height,False)
|
||||
)
|
||||
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
|
||||
resized_latents,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
strength=strength,
|
||||
noise=second_pass_noise,
|
||||
callback=step_callback)
|
||||
|
||||
# decode it
|
||||
samples = ddim_sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
all_timesteps_count=steps
|
||||
)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to('cpu')
|
||||
self.model.cond_stage_model.device = 'cpu'
|
||||
self.model.cond_stage_model.to('cpu')
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
# FIXME: do we really need something entirely different for the inpainting model?
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpainting model is so conservative
|
||||
# it doesn't change the image (much)
|
||||
def inpaint_make_image(x_T):
|
||||
omnibus = Omnibus(self.model,self.precision)
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
width=init_width,
|
||||
height=init_height,
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
assert result is not None and len(result)>0,'** txt2img failed **'
|
||||
image = result[0][0]
|
||||
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
|
||||
print(kwargs.pop('init_image',None))
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=interpolated_image,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=result[0][1],
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
return result[0][0]
|
||||
|
||||
if sampler.uses_inpainting_model():
|
||||
return inpaint_make_image
|
||||
return make_image
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(like, device='cpu').to(device)
|
||||
else:
|
||||
return make_image
|
||||
x = torch.randn_like(like, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = like.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height,scale = True):
|
||||
@@ -179,4 +124,3 @@ class Txt2Img2Img(Generator):
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device=device)
|
||||
|
||||
|
@@ -8,11 +8,14 @@ the attributes:
- root - the root directory under which "models" and "outputs" can be found
- initfile - path to the initialization file
- try_patchmatch - option to globally disable loading of 'patchmatch' module
- always_use_cpu - force use of CPU even if GPU is available
'''

import os
import os.path as osp
from pathlib import Path
from argparse import Namespace
from typing import Union

Globals = Namespace()

@@ -26,6 +29,41 @@ else:

# Where to look for the initialization file
Globals.initfile = 'invokeai.init'
Globals.models_dir = 'models'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'

# Try loading patchmatch
Globals.try_patchmatch = True

# Use CPU even if GPU is available (main use case is for debugging MPS issues)
Globals.always_use_cpu = False

# Whether the internet is reachable for dynamic downloads
# The CLI will test connectivity at startup time.
Globals.internet_available = True

def global_config_dir()->Path:
    return Path(Globals.root, Globals.config_dir)

def global_models_dir()->Path:
    return Path(Globals.root, Globals.models_dir)

def global_autoscan_dir()->Path:
    return Path(Globals.root, Globals.autoscan_dir)

def global_set_root(root_dir:Union[str,Path]):
    Globals.root = root_dir

def global_cache_dir(subdir:Union[str,Path]='')->Path:
    '''
    Returns Path to the model cache directory. If a subdirectory
    is provided, it will be appended to the end of the path, allowing
    for huggingface-style conventions:
         global_cache_dir('diffusers')
         global_cache_dir('transformers')
    '''
    if (home := os.environ.get('HF_HOME')):
        return Path(home,subdir)
    else:
        return Path(Globals.root,'models',subdir)
|
||||
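A hedged usage sketch of the helpers above (the root path is illustrative): every directory resolves under Globals.root, except the cache directory, which prefers HF_HOME when that environment variable is set:

from ldm.invoke.globals import Globals, global_set_root, global_models_dir, global_cache_dir

global_set_root('/home/user/invokeai')      # illustrative root
print(global_models_dir())                  # /home/user/invokeai/models
print(global_cache_dir('diffusers'))        # $HF_HOME/diffusers if HF_HOME is set,
                                            # otherwise /home/user/invokeai/models/diffusers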
|
@@ -1,451 +0,0 @@
|
||||
'''
|
||||
Manage a cache of Stable Diffusion model files for fast switching.
|
||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
cleared and loaded from disk when next needed.
|
||||
'''
|
||||
|
||||
import torch
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import gc
|
||||
import hashlib
|
||||
import psutil
|
||||
import sys
|
||||
import transformers
|
||||
import traceback
|
||||
import textwrap
|
||||
import contextlib
|
||||
from typing import Union
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
from ldm.util import instantiate_from_config, ask_user
|
||||
from ldm.invoke.globals import Globals
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_MAX_MODELS=2
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
|
||||
'''
|
||||
Initialize with the path to the models.yaml config file,
|
||||
the torch device type, and precision. The optional
|
||||
min_avail_mem argument specifies how much unused system
|
||||
(CPU) memory to preserve. The cache of models in RAM will
|
||||
grow until this value is approached. Default is 2G.
|
||||
'''
|
||||
# prevent nasty-looking CLIP log message
|
||||
transformers.logging.set_verbosity_error()
|
||||
self.config = config
|
||||
self.precision = precision
|
||||
self.device = torch.device(device_type)
|
||||
self.max_loaded_models = max_loaded_models
|
||||
self.models = {}
|
||||
self.stack = [] # this is an LRU FIFO
|
||||
self.current_model = None
|
||||
|
||||
def valid_model(self, model_name:str)->bool:
|
||||
'''
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
'''
|
||||
return model_name in self.config
|
||||
|
||||
def get_model(self, model_name:str):
|
||||
'''
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
'''
|
||||
if not self.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return self.current_model
|
||||
|
||||
if self.current_model != model_name:
|
||||
if model_name not in self.models: # make room for a new one
|
||||
self._make_cache_room()
|
||||
self.offload_model(self.current_model)
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]['model']
|
||||
print(f'>> Retrieving model {model_name} from system RAM cache')
|
||||
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
||||
width = self.models[model_name]['width']
|
||||
height = self.models[model_name]['height']
|
||||
hash = self.models[model_name]['hash']
|
||||
|
||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||
try:
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
self.models[model_name] = {
|
||||
'model': requested_model,
|
||||
'width': width,
|
||||
'height': height,
|
||||
'hash': hash,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
print(traceback.format_exc())
|
||||
assert self.current_model,'** FATAL: no current model to restore to'
|
||||
print(f'** restoring {self.current_model}')
|
||||
self.get_model(self.current_model)
|
||||
return
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
return {
|
||||
'model':requested_model,
|
||||
'width':width,
|
||||
'height':height,
|
||||
'hash': hash
|
||||
}
|
||||
|
||||
def default_model(self) -> str:
|
||||
'''
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
'''
|
||||
for model_name in self.config:
|
||||
if self.config[model_name].get('default'):
|
||||
return model_name
|
||||
|
||||
def set_default_model(self,model_name:str) -> None:
|
||||
'''
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_cache.commit()
|
||||
'''
|
||||
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||
|
||||
config = self.config
|
||||
for model in config:
|
||||
config[model].pop('default',None)
|
||||
config[model_name]['default'] = True
|
||||
|
||||
def list_models(self) -> dict:
|
||||
'''
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
'description': description,
|
||||
},
|
||||
model_name2: { etc }
|
||||
'''
|
||||
models = {}
|
||||
for name in self.config:
|
||||
description = self.config[name].description if 'description' in self.config[name] else '<no description>'
|
||||
weights = self.config[name].weights if 'weights' in self.config[name] else '<no weights>'
|
||||
config = self.config[name].config if 'config' in self.config[name] else '<no config>'
|
||||
width = self.config[name].width if 'width' in self.config[name] else 512
|
||||
height = self.config[name].height if 'height' in self.config[name] else 512
|
||||
default = self.config[name].default if 'default' in self.config[name] else False
|
||||
vae = self.config[name].vae if 'vae' in self.config[name] else '<no vae>'
|
||||
|
||||
if self.current_model == name:
|
||||
status = 'active'
|
||||
elif name in self.models:
|
||||
status = 'cached'
|
||||
else:
|
||||
status = 'not loaded'
|
||||
|
||||
models[name]={
|
||||
'status' : status,
|
||||
'description' : description,
|
||||
'weights': weights,
|
||||
'config': config,
|
||||
'width': width,
|
||||
'height': height,
|
||||
'vae': vae,
|
||||
'default': default
|
||||
}
|
||||
return models
|
||||
|
||||
def print_models(self) -> None:
|
||||
'''
|
||||
Print a table of models, their descriptions, and load status
|
||||
'''
|
||||
models = self.list_models()
|
||||
for name in models:
|
||||
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
|
||||
if models[name]['status'] == 'active':
|
||||
line = f'\033[1m{line}\033[0m'
|
||||
print(line)
|
||||
|
||||
def del_model(self, model_name:str) -> None:
|
||||
'''
|
||||
Delete the named model.
|
||||
'''
|
||||
omega = self.config
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> None:
|
||||
'''
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and the
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
'''
|
||||
omega = self.config
|
||||
for field in ('description','weights','height','width','config'):
|
||||
assert field in model_attributes, f'required field {field} is missing'
|
||||
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
|
||||
|
||||
config = omega[model_name] if model_name in omega else {}
|
||||
for field in model_attributes:
|
||||
if field == 'weights':
|
||||
field.replace('\\', '/')
|
||||
config[field] = model_attributes[field]
|
||||
|
||||
omega[model_name] = config
|
||||
if clobber:
|
||||
self._invalidate_cached_model(model_name)
|
||||
|
||||
def _load_model(self, model_name:str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
|
||||
mconfig = self.config[model_name]
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
vae = mconfig.get('vae')
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
if not os.path.isabs(weights):
|
||||
weights = os.path.normpath(os.path.join(Globals.root,weights))
|
||||
# scan model
|
||||
self.scan_model(model_name, weights)
|
||||
|
||||
print(f'>> Loading {model_name} from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root,config)
|
||||
omega_config = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
# merged models from auto11 merge board are flat for some reason
|
||||
if 'state_dict' in sd:
|
||||
sd = sd['state_dict']
|
||||
|
||||
print(f' | Forcing garbage collection prior to loading new model')
|
||||
gc.collect()
|
||||
model = instantiate_from_config(omega_config.model)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||
if vae:
|
||||
if not os.path.isabs(vae):
|
||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||
if os.path.exists(vae):
|
||||
print(f' | Loading VAE weights from: {vae}')
|
||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||
else:
|
||||
print(f' | VAE file {vae} not found. Skipping.')
|
||||
|
||||
model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
model.cond_stage_model.device = self.device
|
||||
|
||||
model.eval()
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
module._orig_padding_mode = module.padding_mode
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
return model, width, height, model_hash
|
||||
|
||||
def offload_model(self, model_name:str) -> None:
|
||||
'''
|
||||
Offload the indicated model to CPU. Will call
|
||||
_make_cache_room() to free space if needed.
|
||||
'''
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
print(f'>> Offloading {model_name} to CPU')
|
||||
model = self.models[model_name]['model']
|
||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def scan_model(self, model_name, checkpoint):
|
||||
# scan model
|
||||
print(f'>> Scanning Model: {model_name}')
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
print(f'\n### Issues Found In Model: {scan_result.issues_count}')
|
||||
print('### WARNING: The model you are trying to load seems to be infected.')
|
||||
print('### For your safety, InvokeAI will not load this model.')
|
||||
print('### Please use checkpoints from trusted sources.')
|
||||
print("### Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
|
||||
model_safe_check_fail = ask_user('Do you want to continue loading the model?', ['y', 'n'])
|
||||
if model_safe_check_fail.lower() != 'y':
|
||||
print("### Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print('>> Model Scanned. OK!!')
|
||||
|
||||
def search_models(self, search_folder):
|
||||
|
||||
print(f'>> Finding Models In: {search_folder}')
|
||||
models_folder = Path(search_folder).glob('**/*.ckpt')
|
||||
|
||||
files = [x for x in models_folder if x.is_file()]
|
||||
|
||||
found_models = []
|
||||
for file in files:
|
||||
found_models.append({
|
||||
'name': file.stem,
|
||||
'location': str(file.resolve()).replace('\\', '/')
|
||||
})
|
||||
|
||||
return search_folder, found_models
|
||||
|
||||
def _make_cache_room(self) -> None:
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
gc.collect()
|
||||
|
||||
def print_vram_usage(self) -> None:
|
||||
if self._has_cuda:
|
||||
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
|
||||
|
||||
def commit(self,config_file_path:str) -> None:
|
||||
'''
|
||||
Write current configuration out to the indicated file.
|
||||
'''
|
||||
yaml_str = OmegaConf.to_yaml(self.config)
|
||||
if not os.path.isabs(config_file_path):
|
||||
config_file_path = os.path.normpath(os.path.join(Globals.root,opt.conf))
|
||||
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(self.preamble())
|
||||
outfile.write(yaml_str)
|
||||
os.replace(tmpfile,config_file_path)
|
||||
|
||||
def preamble(self) -> str:
|
||||
'''
|
||||
Returns the preamble for the config file.
|
||||
'''
|
||||
return textwrap.dedent('''\
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
''')
|
||||
|
||||
def _invalidate_cached_model(self,model_name:str) -> None:
|
||||
self.offload_model(model_name)
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
self.models.pop(model_name,None)
|
||||
|
||||
def _model_to_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
model.cond_stage_model.device = 'cpu'
|
||||
model.first_stage_model.to('cpu')
|
||||
model.cond_stage_model.to('cpu')
|
||||
model.model.to('cpu')
|
||||
return model.to('cpu')
|
||||
else:
|
||||
return model
|
||||
|
||||
def _model_from_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
model.to(self.device)
|
||||
model.first_stage_model.to(self.device)
|
||||
model.cond_stage_model.to(self.device)
|
||||
model.cond_stage_model.device = self.device
|
||||
return model
|
||||
|
||||
def _pop_oldest_model(self):
|
||||
'''
|
||||
Remove the first element of the FIFO, which ought
|
||||
to be the least recently accessed model. Do not
|
||||
pop the last one, because it is in active use!
|
||||
'''
|
||||
return self.stack.pop(0)
|
||||
|
||||
def _push_newest_model(self,model_name:str) -> None:
|
||||
'''
|
||||
Maintain a simple FIFO. First element is always the
|
||||
least recent, and last element is always the most recent.
|
||||
'''
|
||||
with contextlib.suppress(ValueError):
|
||||
self.stack.remove(model_name)
|
||||
self.stack.append(model_name)
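The stack bookkeeping above amounts to a tiny LRU list: the most recently used model name is always last and eviction pops from the front. A standalone sketch of that behaviour (not the repository's code):

stack = []                        # oldest ... newest

def push_newest(name: str) -> None:
    if name in stack:
        stack.remove(name)        # re-insert at the most-recent end
    stack.append(name)

def pop_oldest() -> str:
    return stack.pop(0)           # least recently used

push_newest('model-a'); push_newest('model-b'); push_newest('model-a')
assert stack == ['model-b', 'model-a']
assert pop_oldest() == 'model-b'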
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.device.type == 'cuda'
|
||||
|
||||
def _cached_sha256(self,path,data) -> Union[str, bytes]:
|
||||
dirname = os.path.dirname(path)
|
||||
basename = os.path.basename(path)
|
||||
base, _ = os.path.splitext(basename)
|
||||
hashpath = os.path.join(dirname,base+'.sha256')
|
||||
|
||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
print(f'>> Calculating sha256 hash of weights file')
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||
|
||||
with open(hashpath,'w') as f:
|
||||
f.write(hash)
|
||||
return hash
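The hash caching above stores the digest in a sidecar file next to the weights and only recomputes it when the weights are newer than the sidecar. A minimal sketch of the same idea (not the repository's code):

import hashlib, os

def cached_sha256(weights_path: str, data: bytes) -> str:
    sidecar = os.path.splitext(weights_path)[0] + '.sha256'
    if os.path.exists(sidecar) and os.path.getmtime(weights_path) <= os.path.getmtime(sidecar):
        with open(sidecar) as f:
            return f.read()                       # reuse the cached digest
    digest = hashlib.sha256(data).hexdigest()
    with open(sidecar, 'w') as f:
        f.write(digest)
    return digest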
|
953
ldm/invoke/model_manager.py
Normal file
@@ -0,0 +1,953 @@
|
||||
'''
|
||||
Manage a cache of Stable Diffusion model files for fast switching.
|
||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
cleared and loaded from disk when next needed.
|
||||
'''
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
import safetensors.torch
|
||||
from pathlib import Path
|
||||
from typing import Union, Any
|
||||
from ldm.util import download_with_progress_bar
|
||||
|
||||
import torch
|
||||
import safetensors
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL, logging as dlogging
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import Globals, global_models_dir, global_autoscan_dir, global_cache_dir
|
||||
from ldm.util import instantiate_from_config, ask_user
|
||||
|
||||
DEFAULT_MAX_MODELS=2
|
||||
|
||||
class ModelManager(object):
|
||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
|
||||
'''
|
||||
Initialize with the path to the models.yaml config file,
|
||||
the torch device type, and precision. The optional
|
||||
min_avail_mem argument specifies how much unused system
|
||||
(CPU) memory to preserve. The cache of models in RAM will
|
||||
grow until this value is approached. Default is 2G.
|
||||
'''
|
||||
# prevent nasty-looking CLIP log message
|
||||
transformers.logging.set_verbosity_error()
|
||||
self.config = config
|
||||
self.precision = precision
|
||||
self.device = torch.device(device_type)
|
||||
self.max_loaded_models = max_loaded_models
|
||||
self.models = {}
|
||||
self.stack = [] # this is an LRU FIFO
|
||||
self.current_model = None
|
||||
|
||||
def valid_model(self, model_name:str)->bool:
|
||||
'''
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
'''
|
||||
return model_name in self.config
|
||||
|
||||
def get_model(self, model_name:str):
|
||||
'''
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
'''
|
||||
if not self.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return self.current_model
|
||||
|
||||
if self.current_model != model_name:
|
||||
if model_name not in self.models: # make room for a new one
|
||||
self._make_cache_room()
|
||||
self.offload_model(self.current_model)
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]['model']
|
||||
print(f'>> Retrieving model {model_name} from system RAM cache')
|
||||
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
||||
width = self.models[model_name]['width']
|
||||
height = self.models[model_name]['height']
|
||||
hash = self.models[model_name]['hash']
|
||||
|
||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
self.models[model_name] = {
|
||||
'model': requested_model,
|
||||
'width': width,
|
||||
'height': height,
|
||||
'hash': hash,
|
||||
}
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
return {
|
||||
'model':requested_model,
|
||||
'width':width,
|
||||
'height':height,
|
||||
'hash': hash
|
||||
}
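As a hedged usage sketch (config path, device and model name are illustrative), the manager is constructed from the parsed models.yaml and hands back a dict describing the loaded model:

from omegaconf import OmegaConf
from ldm.invoke.model_manager import ModelManager

config = OmegaConf.load('configs/models.yaml')              # illustrative path
manager = ModelManager(config, device_type='cuda', precision='float16')

entry = manager.get_model('stable-diffusion-1.5')           # illustrative model name
model, width, height = entry['model'], entry['width'], entry['height']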
|
||||
|
||||
def default_model(self) -> str | None:
|
||||
'''
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
'''
|
||||
for model_name in self.config:
|
||||
if self.config[model_name].get('default'):
|
||||
return model_name
|
||||
|
||||
def set_default_model(self,model_name:str) -> None:
|
||||
'''
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_manager.commit()
|
||||
'''
|
||||
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||
|
||||
config = self.config
|
||||
for model in config:
|
||||
config[model].pop('default',None)
|
||||
config[model_name]['default'] = True
|
||||
|
||||
def model_info(self, model_name:str)->dict:
|
||||
'''
|
||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||
'''
|
||||
if model_name not in self.config:
|
||||
return None
|
||||
return self.config[model_name]
|
||||
|
||||
def model_names(self)->list[str]:
|
||||
'''
|
||||
Return a list consisting of all the names of models defined in models.yaml
|
||||
'''
|
||||
return list(self.config.keys())
|
||||
|
||||
def is_legacy(self,model_name:str)->bool:
|
||||
'''
|
||||
Return true if this is a legacy (.ckpt) model
|
||||
'''
|
||||
info = self.model_info(model_name)
|
||||
if 'weights' in info and info['weights'].endswith('.ckpt'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_models(self) -> dict:
|
||||
'''
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
'description': description,
|
||||
'format': ('ckpt'|'diffusers'|'vae'),
|
||||
},
|
||||
model_name2: { etc }
|
||||
Please use model_manager.models() to get all the model names,
|
||||
model_manager.model_info('model-name') to get the stanza for the model
|
||||
named 'model-name', and model_manager.config to get the full OmegaConf
|
||||
object derived from models.yaml
|
||||
'''
|
||||
models = {}
|
||||
for name in sorted(self.config):
|
||||
stanza = self.config[name]
|
||||
|
||||
# don't include VAEs in listing (legacy style)
|
||||
if 'config' in stanza and '/VAE/' in stanza['config']:
|
||||
continue
|
||||
|
||||
models[name] = dict()
|
||||
format = stanza.get('format','ckpt') # Determine Format
|
||||
|
||||
# Common Attribs
|
||||
description = stanza.get('description', None)
|
||||
if self.current_model == name:
|
||||
status = 'active'
|
||||
elif name in self.models:
|
||||
status = 'cached'
|
||||
else:
|
||||
status = 'not loaded'
|
||||
models[name].update(
|
||||
description = description,
|
||||
format = format,
|
||||
status = status,
|
||||
)
|
||||
|
||||
# Checkpoint Config Parse
|
||||
if format == 'ckpt':
|
||||
models[name].update(
|
||||
config = str(stanza.get('config', None)),
|
||||
weights = str(stanza.get('weights', None)),
|
||||
vae = str(stanza.get('vae', None)),
|
||||
width = str(stanza.get('width', 512)),
|
||||
height = str(stanza.get('height', 512)),
|
||||
)
|
||||
|
||||
# Diffusers Config Parse
|
||||
if (vae := stanza.get('vae',None)):
|
||||
if isinstance(vae,DictConfig):
|
||||
vae = dict(
|
||||
repo_id = str(vae.get('repo_id',None)),
|
||||
path = str(vae.get('path',None)),
|
||||
subfolder = str(vae.get('subfolder',None))
|
||||
)
|
||||
|
||||
if format == 'diffusers':
|
||||
models[name].update(
|
||||
vae = vae,
|
||||
repo_id = str(stanza.get('repo_id', None)),
|
||||
path = str(stanza.get('path',None)),
|
||||
)
|
||||
|
||||
return models
|
||||
|
||||
def print_models(self) -> None:
|
||||
'''
|
||||
Print a table of models, their descriptions, and load status
|
||||
'''
|
||||
models = self.list_models()
|
||||
for name in models:
|
||||
if models[name]['format'] == 'vae':
|
||||
continue
|
||||
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["format"]:10s} {models[name]["description"]}'
|
||||
if models[name]['status'] == 'active':
|
||||
line = f'\033[1m{line}\033[0m'
|
||||
print(line)
|
||||
|
||||
def del_model(self, model_name:str) -> None:
|
||||
'''
|
||||
Delete the named model.
|
||||
'''
|
||||
omega = self.config
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber:bool=False) -> None:
|
||||
'''
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and the
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
'''
|
||||
omega = self.config
|
||||
assert 'format' in model_attributes, 'missing required field "format"'
|
||||
if model_attributes['format']=='diffusers':
|
||||
assert 'description' in model_attributes, 'required field "description" is missing'
|
||||
assert 'path' in model_attributes or 'repo_id' in model_attributes,'model must have either the "path" or "repo_id" fields defined'
|
||||
else:
|
||||
for field in ('description','weights','height','width','config'):
|
||||
assert field in model_attributes, f'required field {field} is missing'
|
||||
|
||||
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
|
||||
|
||||
if model_name not in omega:
|
||||
omega[model_name] = dict()
|
||||
OmegaConf.update(omega,model_name,model_attributes,merge=False)
|
||||
if 'weights' in omega[model_name]:
|
||||
omega[model_name]['weights'].replace('\\','/')
|
||||
|
||||
if clobber:
|
||||
self._invalidate_cached_model(model_name)
|
||||
|
||||
def _load_model(self, model_name:str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return
|
||||
|
||||
mconfig = self.config[model_name]
|
||||
|
||||
# for usage statistics
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
model_format = mconfig.get('format', 'ckpt')
|
||||
if model_format == 'ckpt':
|
||||
weights = mconfig.weights
|
||||
print(f'>> Loading {model_name} from {weights}')
|
||||
model, width, height, model_hash = self._load_ckpt_model(model_name, mconfig)
|
||||
elif model_format == 'diffusers':
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
model, width, height, model_hash = self._load_diffusers_model(mconfig)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print('>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
def _load_ckpt_model(self, model_name, mconfig):
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
vae = mconfig.get('vae')
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root,config)
|
||||
if not os.path.isabs(weights):
|
||||
weights = os.path.normpath(os.path.join(Globals.root,weights))
|
||||
# scan model
|
||||
self.scan_model(model_name, weights)
|
||||
|
||||
print(f'>> Loading {model_name} from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root,config)
|
||||
omega_config = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
model_hash = self._cached_sha256(weights, weight_bytes)
|
||||
sd = None
|
||||
if weights.endswith('.safetensors'):
|
||||
sd = safetensors.torch.load(weight_bytes)
|
||||
else:
|
||||
sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
# merged models from auto11 merge board are flat for some reason
|
||||
if 'state_dict' in sd:
|
||||
sd = sd['state_dict']
|
||||
|
||||
print(' | Forcing garbage collection prior to loading new model')
|
||||
gc.collect()
|
||||
model = instantiate_from_config(omega_config.model)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||
if vae:
|
||||
if not os.path.isabs(vae):
|
||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||
if os.path.exists(vae):
|
||||
print(f' | Loading VAE weights from: {vae}')
|
||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||
else:
|
||||
print(f' | VAE file {vae} not found. Skipping.')
|
||||
|
||||
model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
model.cond_stage_model.device = self.device
|
||||
|
||||
model.eval()
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
module._orig_padding_mode = module.padding_mode
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print('>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
return model, width, height, model_hash
|
||||
|
||||
def _load_diffusers_model(self, mconfig):
|
||||
name_or_path = self.model_name_or_path(mconfig)
|
||||
using_fp16 = self.precision == 'float16'
|
||||
|
||||
print(f'>> Loading diffusers model from {name_or_path}')
|
||||
if using_fp16:
|
||||
print(' | Using faster float16 precision')
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
pipeline_args: dict[str, Any] = dict(
|
||||
safety_checker=None,
|
||||
local_files_only=not Globals.internet_available
|
||||
)
|
||||
if 'vae' in mconfig:
|
||||
vae = self._load_vae(mconfig['vae'])
|
||||
pipeline_args.update(vae=vae)
|
||||
if not isinstance(name_or_path,Path):
|
||||
pipeline_args.update(cache_dir=global_cache_dir('diffusers'))
|
||||
if using_fp16:
|
||||
pipeline_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{'revision':'fp16'},{}]
|
||||
else:
|
||||
fp_args_list = [{}]
|
||||
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipeline = None
|
||||
for fp_args in fp_args_list:
|
||||
try:
|
||||
pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
|
||||
name_or_path,
|
||||
**pipeline_args,
|
||||
**fp_args,
|
||||
)
|
||||
|
||||
except OSError as e:
|
||||
if str(e).startswith('fp16 is not a valid'):
|
||||
print(f'Could not fetch half-precision version of model {name_or_path}; fetching full-precision instead')
|
||||
else:
|
||||
print(f'An unexpected error occurred while downloading the model: {e})')
|
||||
if pipeline:
|
||||
break
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')
|
||||
|
||||
pipeline.to(self.device)
|
||||
|
||||
model_hash = self._diffuser_sha256(name_or_path)
|
||||
|
||||
# square images???
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
|
||||
print(f' | Default image dimensions = {width} x {height}')
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def model_name_or_path(self, model_name:Union[str,DictConfig]) -> str | Path:
|
||||
if isinstance(model_name,DictConfig):
|
||||
mconfig = model_name
|
||||
elif model_name in self.config:
|
||||
mconfig = self.config[model_name]
|
||||
else:
|
||||
raise ValueError(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
|
||||
if 'path' in mconfig:
|
||||
path = Path(mconfig['path'])
|
||||
if not path.is_absolute():
|
||||
path = Path(Globals.root, path).resolve()
|
||||
return path
|
||||
elif 'repo_id' in mconfig:
|
||||
return mconfig['repo_id']
|
||||
else:
|
||||
raise ValueError("Model config must specify either repo_id or path.")
|
||||
|
||||
def offload_model(self, model_name:str) -> None:
|
||||
'''
|
||||
Offload the indicated model to CPU. Will call
|
||||
_make_cache_room() to free space if needed.
|
||||
'''
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
print(f'>> Offloading {model_name} to CPU')
|
||||
model = self.models[model_name]['model']
|
||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def scan_model(self, model_name, checkpoint):
|
||||
'''
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
'''
|
||||
# scan model
|
||||
print(f'>> Scanning Model: {model_name}')
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
print(f'\n### Issues Found In Model: {scan_result.issues_count}')
|
||||
print('### WARNING: The model you are trying to load seems to be infected.')
|
||||
print('### For your safety, InvokeAI will not load this model.')
|
||||
print('### Please use checkpoints from trusted sources.')
|
||||
print("### Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
|
||||
model_safe_check_fail = ask_user('Do you want to continue loading the model?', ['y', 'n'])
|
||||
if model_safe_check_fail.lower() != 'y':
|
||||
print("### Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print('>> Model scanned ok!')
|
||||
|
||||
def import_diffuser_model(self,
|
||||
repo_or_path:Union[str,Path],
|
||||
model_name:str=None,
|
||||
description:str=None,
|
||||
commit_to_conf:Path=None,
|
||||
)->bool:
|
||||
'''
|
||||
Attempts to install the indicated diffuser model and returns True if successful.
|
||||
|
||||
"repo_or_path" can be either a repo-id or a path-like object corresponding to the
|
||||
top of a downloaded diffusers directory.
|
||||
|
||||
You can optionally provide a model name and/or description. If not provided,
|
||||
then these will be derived from the repo name. If you provide a commit_to_conf
|
||||
path to the configuration file, then the new entry will be committed to the
|
||||
models.yaml file.
|
||||
'''
|
||||
model_name = model_name or Path(repo_or_path).stem
|
||||
description = description or f'imported diffusers model {model_name}'
|
||||
new_config = dict(
|
||||
description=description,
|
||||
format='diffusers',
|
||||
)
|
||||
if isinstance(repo_or_path,Path) and repo_or_path.exists():
|
||||
new_config.update(path=repo_or_path)
|
||||
else:
|
||||
new_config.update(repo_id=repo_or_path)
|
||||
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
return True
|
||||
|
||||
def import_ckpt_model(self,
|
||||
weights:Union[str,Path],
|
||||
config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
|
||||
model_name:str=None,
|
||||
model_description:str=None,
|
||||
commit_to_conf:Path=None,
|
||||
)->bool:
|
||||
'''
|
||||
Attempts to install the indicated ckpt file and returns True if successful.
|
||||
|
||||
"weights" can be either a path-like object corresponding to a local .ckpt file
|
||||
or a http/https URL pointing to a remote model.
|
||||
|
||||
"config" is the model config file to use with this ckpt file. It defaults to
|
||||
v1-inference.yaml. If a URL is provided, the config will be downloaded.
|
||||
|
||||
You can optionally provide a model name and/or description. If not provided,
|
||||
then these will be derived from the weight file name. If you provide a commit_to_conf
|
||||
path to the configuration file, then the new entry will be committed to the
|
||||
models.yaml file.
|
||||
'''
|
||||
weights_path = self._resolve_path(weights,'models/ldm/stable-diffusion-v1')
|
||||
config_path = self._resolve_path(config,'configs/stable-diffusion')
|
||||
|
||||
if weights_path is None or not weights_path.exists():
|
||||
return False
|
||||
if config_path is None or not config_path.exists():
|
||||
return False
|
||||
|
||||
model_name = model_name or Path(weights).stem
|
||||
model_description = model_description or f'imported stable diffusion weights file {model_name}'
|
||||
new_config = dict(
|
||||
weights=str(weights_path),
|
||||
config=str(config_path),
|
||||
description=model_description,
|
||||
format='ckpt',
|
||||
width=512,
|
||||
height=512
|
||||
)
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
return True
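A hedged usage sketch of the two import helpers above (repo id, file paths and names are illustrative, not from the commit; 'manager' is a ModelManager instance as constructed in the earlier sketch):

manager.import_diffuser_model(
    'some-org/some-diffusers-model',             # repo id or local diffusers directory
    model_name='my-diffusers-model',
    description='imported diffusers model',
    commit_to_conf='configs/models.yaml',
)

manager.import_ckpt_model(
    'models/ldm/stable-diffusion-v1/my-weights.ckpt',
    config='configs/stable-diffusion/v1-inference.yaml',
    model_name='my-ckpt-model',
    commit_to_conf='configs/models.yaml',
)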
|
||||
|
||||
def autoconvert_weights(
|
||||
self,
|
||||
conf_path:Path,
|
||||
weights_directory:Path=None,
|
||||
dest_directory:Path=None,
|
||||
):
|
||||
'''
|
||||
Scan the indicated directory for .ckpt files, convert into diffuser models,
|
||||
and import.
|
||||
'''
|
||||
weights_directory = weights_directory or global_autoscan_dir()
|
||||
dest_directory = dest_directory or Path(global_models_dir(), 'optimized-ckpts')
|
||||
|
||||
print(f'>> Checking for unconverted .ckpt files in {weights_directory}')
|
||||
ckpt_files = dict()
|
||||
for root, dirs, files in os.walk(weights_directory):
|
||||
for f in files:
|
||||
if not f.endswith('.ckpt'):
|
||||
continue
|
||||
basename = Path(f).stem
|
||||
dest = Path(dest_directory,basename)
|
||||
if not dest.exists():
|
||||
ckpt_files[Path(root,f)]=dest
|
||||
|
||||
if len(ckpt_files)==0:
|
||||
return
|
||||
|
||||
print(f'>> New .ckpt file(s) found in {weights_directory}. Optimizing and importing...')
|
||||
for ckpt in ckpt_files:
|
||||
self.convert_and_import(ckpt, ckpt_files[ckpt])
|
||||
self.commit(conf_path)
|
||||
|
||||
def convert_and_import(self,
|
||||
ckpt_path:Path,
|
||||
diffuser_path:Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
commit_to_conf:Path=None,
|
||||
)->dict:
|
||||
'''
|
||||
Convert a legacy ckpt weights file to diffuser model and import
|
||||
into models.yaml.
|
||||
'''
|
||||
new_config = None
|
||||
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
|
||||
import transformers
|
||||
if diffuser_path.exists():
|
||||
print(f'ERROR: The path {str(diffuser_path)} already exists. Please move or remove it and try again.')
|
||||
return
|
||||
|
||||
model_name = model_name or diffuser_path.name
|
||||
model_description = model_description or f'Optimized version of {model_name}'
|
||||
print(f'>> {model_name}: optimizing (30-60s).')
|
||||
try:
|
||||
verbosity =transformers.logging.get_verbosity()
|
||||
transformers.logging.set_verbosity_error()
|
||||
convert_ckpt_to_diffuser(ckpt_path, diffuser_path,extract_ema=True)
|
||||
transformers.logging.set_verbosity(verbosity)
|
||||
print(f'>> Success. Optimized model is now located at {str(diffuser_path)}')
|
||||
print(f'>> Writing new config file entry for {model_name}...',end='')
|
||||
new_config = dict(
|
||||
path=str(diffuser_path),
|
||||
description=model_description,
|
||||
format='diffusers',
|
||||
)
|
||||
self.del_model(model_name)
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
except Exception as e:
|
||||
print(f'** Conversion failed: {str(e)}')
|
||||
traceback.print_exc()
|
||||
|
||||
print('done.')
|
||||
return new_config
|
||||
|
||||
def del_config(self, model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
gen.model_manager.del_model(model_name)
|
||||
gen.model_manager.commit(opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.del_model(model_name)
|
||||
|
||||
def search_models(self, search_folder):
|
||||
print(f'>> Finding Models In: {search_folder}')
|
||||
models_folder_ckpt = Path(search_folder).glob('**/*.ckpt')
|
||||
models_folder_safetensors = Path(search_folder).glob('**/*.safetensors')
|
||||
|
||||
ckpt_files = [x for x in models_folder_ckpt if x.is_file()]
|
||||
safetensor_files = [x for x in models_folder_safetensors if x.is_file()]
|
||||
|
||||
files = ckpt_files + safetensor_files
|
||||
|
||||
found_models = []
|
||||
for file in files:
|
||||
found_models.append({
|
||||
'name': file.stem,
|
||||
'location': str(file.resolve()).replace('\\', '/')
|
||||
})
|
||||
|
||||
return search_folder, found_models
|
||||
|
||||
    def _make_cache_room(self) -> None:
        num_loaded_models = len(self.models)
        if num_loaded_models >= self.max_loaded_models:
            least_recent_model = self._pop_oldest_model()
            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
            if least_recent_model is not None:
                del self.models[least_recent_model]
                gc.collect()

    def print_vram_usage(self) -> None:
        if self._has_cuda:
            print('>> Current VRAM usage: ', '%4.2fG' % (torch.cuda.memory_allocated() / 1e9))

    def commit(self, config_file_path: str) -> None:
        '''
        Write current configuration out to the indicated file.
        '''
        yaml_str = OmegaConf.to_yaml(self.config)
        if not os.path.isabs(config_file_path):
            config_file_path = os.path.normpath(os.path.join(Globals.root, config_file_path))
        tmpfile = os.path.join(os.path.dirname(config_file_path), 'new_config.tmp')
        with open(tmpfile, 'w', encoding="utf-8") as outfile:
            outfile.write(self.preamble())
            outfile.write(yaml_str)
        os.replace(tmpfile, config_file_path)
    def preamble(self) -> str:
        '''
        Returns the preamble for the config file.
        '''
        return textwrap.dedent('''\
            # This file describes the alternative machine learning models
            # available to InvokeAI script.
            #
            # To add a new model, follow the examples below. Each
            # model requires a model config file, a weights file,
            # and the width and height of the images it
            # was trained on.
        ''')
    @classmethod
    def migrate_models(cls):
        '''
        Migrate the ~/invokeai/models directory from the legacy format used through 2.2.5
        to the 2.3.0 "diffusers" version. This should be a one-time operation, called at
        script startup time.
        '''
        # Three transformer models to check: bert, clip and safety checker
        legacy_locations = [
            Path('CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker'),
            Path('bert-base-uncased/models--bert-base-uncased'),
            Path('openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14')
        ]
        models_dir = Path(Globals.root, 'models')
        legacy_layout = False
        for model in legacy_locations:
            legacy_layout = legacy_layout or Path(models_dir, model).exists()
        if not legacy_layout:
            return

        print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
        print('** This is a quick one-time operation.')
        from shutil import move

        # transformer files get moved into the hub directory
        hub = models_dir / 'hub'
        os.makedirs(hub, exist_ok=True)
        for model in legacy_locations:
            source = models_dir / model
            if source.exists():
                print(f'DEBUG: Moving {models_dir / model} into hub')
                move(models_dir / model, hub)

        # anything else gets moved into the diffusers directory
        diffusers = models_dir / 'diffusers'
        os.makedirs(diffusers, exist_ok=True)
        for root, dirs, _ in os.walk(models_dir, topdown=False):
            for dir in dirs:
                full_path = Path(root, dir)
                if full_path.is_relative_to(hub) or full_path.is_relative_to(diffusers):
                    continue
                if Path(dir).match('models--*--*'):
                    move(full_path, diffusers)

        # now clean up by removing any empty directories
        empty = [root for root, dirs, files in os.walk(models_dir) if not len(dirs) and not len(files)]
        for d in empty:
            os.rmdir(d)
        print('** Migration is done. Continuing...')
    def _resolve_path(self, source: Union[str, Path], dest_directory: str) -> Path:
        resolved_path = None
        if source.startswith(('http:', 'https:', 'ftp:')):
            basename = os.path.basename(source)
            if not os.path.isabs(dest_directory):
                dest_directory = os.path.join(Globals.root, dest_directory)
            dest = os.path.join(dest_directory, basename)
            if download_with_progress_bar(source, dest):
                resolved_path = Path(dest)
        else:
            if not os.path.isabs(source):
                source = os.path.join(Globals.root, source)
            resolved_path = Path(source)
        return resolved_path
    def _invalidate_cached_model(self, model_name: str) -> None:
        self.offload_model(model_name)
        if model_name in self.stack:
            self.stack.remove(model_name)
        self.models.pop(model_name, None)

    def _model_to_cpu(self, model):
        if self.device == 'cpu':
            return model

        # diffusers really really doesn't like us moving a float16 model onto CPU
        import logging
        logging.getLogger('diffusers.pipeline_utils').setLevel(logging.CRITICAL)
        model.cond_stage_model.device = 'cpu'
        model.to('cpu')
        logging.getLogger('pipeline_utils').setLevel(logging.INFO)

        for submodel in ('first_stage_model', 'cond_stage_model', 'model'):
            try:
                getattr(model, submodel).to('cpu')
            except AttributeError:
                pass
        return model

    def _model_from_cpu(self, model):
        if self.device == 'cpu':
            return model

        model.to(self.device)
        model.cond_stage_model.device = self.device

        for submodel in ('first_stage_model', 'cond_stage_model', 'model'):
            try:
                getattr(model, submodel).to(self.device)
            except AttributeError:
                pass

        return model
    def _pop_oldest_model(self):
        '''
        Remove the first element of the FIFO, which ought
        to be the least recently accessed model. Do not
        pop the last one, because it is in active use!
        '''
        return self.stack.pop(0)

    def _push_newest_model(self, model_name: str) -> None:
        '''
        Maintain a simple FIFO. First element is always the
        least recent, and last element is always the most recent.
        '''
        with contextlib.suppress(ValueError):
            self.stack.remove(model_name)
        self.stack.append(model_name)

    def _has_cuda(self) -> bool:
        return self.device.type == 'cuda'
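For intuition, a self-contained toy version of the FIFO bookkeeping implemented by _push_newest_model() and _pop_oldest_model() above (illustrative only, not the real class):

# Standalone sketch of the cache-ordering idea.
stack = []

def push_newest(name):
    if name in stack:
        stack.remove(name)
    stack.append(name)

for name in ('stable-diffusion-1.5', 'inpainting-1.5', 'stable-diffusion-1.5'):
    push_newest(name)

print(stack)         # ['inpainting-1.5', 'stable-diffusion-1.5'] -- most recent last
print(stack.pop(0))  # 'inpainting-1.5' -- the least recently used entry is purged first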
    def _diffuser_sha256(self, name_or_path: Union[str, Path]) -> Union[str, bytes]:
        path = None
        if isinstance(name_or_path, Path):
            path = name_or_path
        else:
            owner, repo = name_or_path.split('/')
            path = Path(global_cache_dir('diffusers') / f'models--{owner}--{repo}')
        if not path.exists():
            return None
        hashpath = path / 'checksum.sha256'
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                hash = f.read()
            return hash
        print(' | Calculating sha256 hash of model files')
        tic = time.time()
        sha = hashlib.sha256()
        count = 0
        for root, dirs, files in os.walk(path, followlinks=False):
            for name in files:
                count += 1
                with open(os.path.join(root, name), 'rb') as f:
                    sha.update(f.read())
        hash = sha.hexdigest()
        toc = time.time()
        print(f' | sha256 = {hash} ({count} files hashed in', '%4.2fs)' % (toc - tic))
        with open(hashpath, 'w') as f:
            f.write(hash)
        return hash
    def _cached_sha256(self, path, data) -> Union[str, bytes]:
        dirname = os.path.dirname(path)
        basename = os.path.basename(path)
        base, _ = os.path.splitext(basename)
        hashpath = os.path.join(dirname, base + '.sha256')

        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
            with open(hashpath) as f:
                hash = f.read()
            return hash

        print(' | Calculating sha256 hash of weights file')
        tic = time.time()
        sha = hashlib.sha256()
        sha.update(data)
        hash = sha.hexdigest()
        toc = time.time()
        print(f'>> sha256 = {hash}', '(%4.2fs)' % (toc - tic))

        with open(hashpath, 'w') as f:
            f.write(hash)
        return hash
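The two hashing methods above share one pattern: recompute the digest only when the weights are newer than the stored checksum. Below is a standalone sketch of that pattern under my own naming; the function name and the sorted() ordering are additions, not part of the original file.

import hashlib
import os
from pathlib import Path

def cached_tree_sha256(path: Path) -> str:
    '''Return a cached sha256 over every file under `path`, refreshing it when stale.'''
    hashpath = path / 'checksum.sha256'
    if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
        return hashpath.read_text()
    sha = hashlib.sha256()
    for root, _, files in os.walk(path, followlinks=False):
        for name in sorted(files):          # sorted for a deterministic digest
            sha.update(Path(root, name).read_bytes())
    digest = sha.hexdigest()
    hashpath.write_text(digest)
    return digest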
    def _load_vae(self, vae_config):
        vae_args = {}
        name_or_path = self.model_name_or_path(vae_config)
        using_fp16 = self.precision == 'float16'

        vae_args.update(
            cache_dir=global_cache_dir('diffusers'),
            local_files_only=not Globals.internet_available,
        )

        print(f' | Loading diffusers VAE from {name_or_path}')
        if using_fp16:
            vae_args.update(torch_dtype=torch.float16)
            fp_args_list = [{'revision': 'fp16'}, {}]
        else:
            print(' | Using more accurate float32 precision')
            fp_args_list = [{}]

        vae = None
        deferred_error = None

        # A VAE may be in a subfolder of a model's repository.
        if 'subfolder' in vae_config:
            vae_args['subfolder'] = vae_config['subfolder']

        for fp_args in fp_args_list:
            # At some point we might need to be able to use different classes here? But for now I think
            # all Stable Diffusion VAE are AutoencoderKL.
            try:
                vae = AutoencoderKL.from_pretrained(name_or_path, **vae_args, **fp_args)
            except OSError as e:
                if str(e).startswith('fp16 is not a valid'):
                    print(' | Half-precision version of model not available; fetching full-precision instead')
                else:
                    deferred_error = e
            if vae:
                break

        if not vae and deferred_error:
            print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}')

        return vae
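Read in isolation, the fp16-then-fp32 retry loop in _load_vae() amounts to roughly the following; the repo id is only an example and the exception handling simply mirrors the method above, so treat this as a sketch rather than verified behavior.

import torch
from diffusers import AutoencoderKL

vae = None
for fp_args in ({'revision': 'fp16', 'torch_dtype': torch.float16}, {}):
    try:
        vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **fp_args)
        break                      # stop at the first variant that loads
    except OSError as e:
        print(f' | fp16 variant unavailable, falling back to float32: {e}')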
@ -12,7 +12,7 @@ import os
|
||||
import re
|
||||
import atexit
|
||||
from ldm.invoke.args import Args
|
||||
from ldm.invoke.concepts_lib import Concepts
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
# ---------------readline utilities---------------------
|
||||
@ -24,7 +24,7 @@ except (ImportError,ModuleNotFoundError) as e:
|
||||
readline_available = False
|
||||
|
||||
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
|
||||
WEIGHT_EXTENSIONS = ('.ckpt','.bae')
|
||||
WEIGHT_EXTENSIONS = ('.ckpt','.vae','.safetensors')
|
||||
TEXT_EXTENSIONS = ('.txt','.TXT')
|
||||
CONFIG_EXTENSIONS = ('.yaml','.yml')
|
||||
COMMANDS = (
|
||||
@ -59,7 +59,7 @@ COMMANDS = (
|
||||
'--png_compression','-z',
|
||||
'--text_mask','-tm',
|
||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||
'!models','!switch','!import_model','!edit_model','!del_model',
|
||||
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
|
||||
'!mask',
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
@ -67,8 +67,12 @@ MODEL_COMMANDS = (
|
||||
'!edit_model',
|
||||
'!del_model',
|
||||
)
|
||||
CKPT_MODEL_COMMANDS = (
|
||||
'!optimize_model',
|
||||
)
|
||||
WEIGHT_COMMANDS = (
|
||||
'!import_model',
|
||||
'!convert_model',
|
||||
)
|
||||
IMG_PATH_COMMANDS = (
|
||||
'--outdir[=\s]',
|
||||
@ -91,9 +95,9 @@ weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
|
||||
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, options, models=[]):
|
||||
def __init__(self, options, models={}):
|
||||
self.options = sorted(options)
|
||||
self.models = sorted(models)
|
||||
self.models = models
|
||||
self.seeds = set()
|
||||
self.matches = list()
|
||||
self.default_dir = None
|
||||
@ -134,6 +138,10 @@ class Completer(object):
|
||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state)
|
||||
|
||||
# looking for a ckpt model
|
||||
elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state, ckpt_only=True)
|
||||
|
||||
elif re.search(weight_regexp,buffer):
|
||||
self.matches = self._path_completions(
|
||||
text,
|
||||
@ -242,18 +250,12 @@ class Completer(object):
|
||||
self.linebuffer = line
|
||||
readline.redisplay()
|
||||
|
||||
def add_model(self,model_name:str)->None:
|
||||
def update_models(self,models:dict)->None:
|
||||
'''
|
||||
add a model name to the completion list
|
||||
update our list of models
|
||||
'''
|
||||
self.models.append(model_name)
|
||||
|
||||
def del_model(self,model_name:str)->None:
|
||||
'''
|
||||
removes a model name from the completion list
|
||||
'''
|
||||
self.models.remove(model_name)
|
||||
|
||||
self.models = models
|
||||
|
||||
def _seed_completions(self, text, state):
|
||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||
if m:
|
||||
@ -278,7 +280,7 @@ class Completer(object):
|
||||
def _concept_completions(self, text, state):
|
||||
if self.concepts is None:
|
||||
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
|
||||
self.concepts = Concepts()
|
||||
self.concepts = HuggingFaceConceptsLibrary()
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
else:
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
@ -294,7 +296,7 @@ class Completer(object):
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _model_completions(self, text, state):
|
||||
def _model_completions(self, text, state, ckpt_only=False):
|
||||
m = re.search('(!switch\s+)(\w*)',text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
@ -304,6 +306,11 @@ class Completer(object):
|
||||
partial = text
|
||||
matches = list()
|
||||
for s in self.models:
|
||||
format = self.models[s]['format']
|
||||
if format == 'vae':
|
||||
continue
|
||||
if ckpt_only and format != 'ckpt':
|
||||
continue
|
||||
if s.startswith(partial):
|
||||
matches.append(switch+s)
|
||||
matches.sort()
|
||||
|
@ -12,6 +12,7 @@ def configure_model_padding(model, seamless, seamless_axes):
|
||||
"""
|
||||
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
|
||||
"""
|
||||
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
if seamless:
|
||||
|
799 ldm/invoke/textual_inversion_training.py Normal file
@ -0,0 +1,799 @@
|
||||
# This code was copied from
|
||||
# https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
|
||||
# on January 2, 2023
|
||||
# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI
|
||||
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import datasets
|
||||
import diffusers
|
||||
import PIL
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
# invokeai stuff
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path):
|
||||
logger.info("Saving embeddings")
|
||||
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
||||
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
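For reference, the .bin file written by save_progress() is an ordinary torch pickle mapping the placeholder token to its learned embedding tensor; the file name below is an assumption.

import torch

learned = torch.load('learned_embeds.bin', map_location='cpu')
for token, embedding in learned.items():
    print(token, tuple(embedding.shape))   # e.g. ('<my-concept>', (768,)) for SD 1.x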
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save learned_embeds.bin every X updates steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--root_dir','--root',
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only_save_embeds",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Save only the embeddings for the new concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A token to use as a placeholder for the concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initializer_token",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="A token to use as initializer word."
|
||||
)
|
||||
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
|
||||
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=Path,
|
||||
default=f'{Globals.root}/text-inversion-model',
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=Path,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=Path,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL_INTERPOLATION["linear"],
|
||||
"bilinear": PIL_INTERPOLATION["bilinear"],
|
||||
"bicubic": PIL_INTERPOLATION["bicubic"],
|
||||
"lanczos": PIL_INTERPOLATION["lanczos"],
|
||||
}[interpolation]
|
||||
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
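A hedged usage sketch of the dataset class above: the image folder and placeholder token are placeholders, and the tokenizer repo is the standard SD 1.x text encoder.

from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
dataset = TextualInversionDataset(
    data_root='/path/to/concept/images',
    tokenizer=tokenizer,
    placeholder_token='<my-concept>',
    learnable_property='object',
    size=512,
    repeats=100,
)
batch = next(iter(DataLoader(dataset, batch_size=4, shuffle=True)))
print(batch['pixel_values'].shape)   # torch.Size([4, 3, 512, 512])
print(batch['input_ids'].shape)      # torch.Size([4, 77])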
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def do_textual_inversion_training(
|
||||
model:str,
|
||||
train_data_dir:Path,
|
||||
output_dir:Path,
|
||||
placeholder_token:str,
|
||||
initializer_token:str,
|
||||
save_steps:int=500,
|
||||
only_save_embeds:bool=False,
|
||||
revision:str=None,
|
||||
tokenizer_name:str=None,
|
||||
learnable_property:str='object',
|
||||
repeats:int=100,
|
||||
seed:int=None,
|
||||
resolution:int=512,
|
||||
center_crop:bool=False,
|
||||
train_batch_size:int=16,
|
||||
num_train_epochs:int=100,
|
||||
max_train_steps:int=5000,
|
||||
gradient_accumulation_steps:int=1,
|
||||
gradient_checkpointing:bool=False,
|
||||
learning_rate:float=1e-4,
|
||||
scale_lr:bool=True,
|
||||
lr_scheduler:str='constant',
|
||||
lr_warmup_steps:int=500,
|
||||
adam_beta1:float=0.9,
|
||||
adam_beta2:float=0.999,
|
||||
adam_weight_decay:float=1e-02,
|
||||
adam_epsilon:float=1e-08,
|
||||
push_to_hub:bool=False,
|
||||
hub_token:str=None,
|
||||
logging_dir:Path=Path('logs'),
|
||||
mixed_precision:str='fp16',
|
||||
allow_tf32:bool=False,
|
||||
report_to:str='tensorboard',
|
||||
local_rank:int=-1,
|
||||
checkpointing_steps:int=500,
|
||||
resume_from_checkpoint:Path=None,
|
||||
enable_xformers_memory_efficient_attention:bool=False,
|
||||
root_dir:Path=None
|
||||
):
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != local_rank:
|
||||
local_rank = env_local_rank
|
||||
|
||||
# setting up things the way invokeai expects them
|
||||
if not os.path.isabs(output_dir):
|
||||
output_dir = os.path.join(Globals.root,output_dir)
|
||||
|
||||
logging_dir = output_dir / logging_dir
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
mixed_precision=mixed_precision,
|
||||
log_with=report_to,
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
|
||||
# Handle the repository creation
|
||||
    if accelerator.is_main_process:
        if push_to_hub:
            # NB: the upstream HF script takes a --hub_model_id argument; it is not
            # exposed by this function, so the repo name is derived from output_dir.
            repo_name = get_full_repo_name(Path(output_dir).name, token=hub_token)
            repo = Repository(output_dir, clone_from=repo_name)

            with open(os.path.join(output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
models_conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
||||
model_conf = models_conf.get(model,None)
|
||||
assert model_conf is not None,f'Unknown model: {model}'
|
||||
assert model_conf.get('format','diffusers')=='diffusers', "This script only works with models of type 'diffusers'"
|
||||
pretrained_model_name_or_path = model_conf.get('repo_id',None) or Path(model_conf.get('path'))
|
||||
assert pretrained_model_name_or_path, f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
||||
pipeline_args = dict(cache_dir=global_cache_dir('diffusers'))
|
||||
|
||||
# Load tokenizer
|
||||
if tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name,cache_dir=global_cache_dir('transformers'))
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, **pipeline_args
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision, **pipeline_args)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="unet", revision=revision, **pipeline_args
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError(f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# Freeze vae and unet
|
||||
vae.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.text_model.encoder.requires_grad_(False)
|
||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if gradient_checkpointing:
|
||||
# Keep unet in train mode if we are using gradient checkpointing to save memory.
|
||||
# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
|
||||
unet.train()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
if enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if scale_lr:
|
||||
learning_rate = (
|
||||
learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
|
||||
)
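    # Worked example of the scaling rule above (illustrative numbers, not part of
    # the original script): with learning_rate=1e-4, gradient_accumulation_steps=1,
    # train_batch_size=16 and accelerator.num_processes=2, the effective learning
    # rate becomes 1e-4 * 1 * 16 * 2 = 3.2e-3.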
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
betas=(adam_beta1, adam_beta2),
|
||||
weight_decay=adam_weight_decay,
|
||||
eps=adam_epsilon,
|
||||
)
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=train_data_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=resolution,
|
||||
placeholder_token=placeholder_token,
|
||||
repeats=repeats,
|
||||
learnable_property=learnable_property,
|
||||
center_crop=center_crop,
|
||||
set="train",
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||
if max_train_steps is None:
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
||||
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move vae and unet to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
params = locals()
|
||||
for k in params: # init_trackers() doesn't like objects
|
||||
params[k] = str(params[k]) if isinstance(params[k],object) else params[k]
|
||||
accelerator.init_trackers("textual_inversion", config=params)
|
||||
|
||||
# Train!
|
||||
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if resume_from_checkpoint:
|
||||
if resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1]
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * gradient_accumulation_steps
|
||||
first_epoch = resume_global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % num_update_steps_per_epoch
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# keep original embeddings as reference
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
|
||||
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % save_steps == 0:
|
||||
save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
|
||||
|
||||
if global_step % checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
    # Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if push_to_hub and only_save_embeds:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not only_save_embeds
|
||||
if save_full_model:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
**pipeline_args,
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
# Save the newly trained embeddings
|
||||
save_path = os.path.join(output_dir, "learned_embeds.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
|
||||
|
||||
if push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
accelerator.end_training()
|
@ -4,7 +4,9 @@ from typing import Optional, Callable
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import diffusers
|
||||
from torch import nn
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
@ -337,8 +339,8 @@ def setup_cross_attention_control(model, context: Context):
|
||||
|
||||
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
cross_attention_class: type = InvokeAICrossAttentionMixin
|
||||
# cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||
from ldm.modules.attention import CrossAttention # avoid circular import
|
||||
cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
||||
isinstance(module, cross_attention_class) and which_attn in name]
|
||||
@ -441,3 +443,19 @@ def get_mem_free_total(device):
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
#default_result = super()._attention(query, key, value)
|
||||
if attention_mask is not None:
|
||||
print(f"{type(self).__name__} ignoring passed-in attention_mask")
|
||||
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||
|
||||
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
|
||||
return hidden_states
|
||||
|
||||
|
@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from omegaconf import ListConfig
|
||||
import urllib
|
||||
|
||||
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||
from ldm.util import (
|
||||
log_txt_as_img,
|
||||
exists,
|
||||
@ -678,6 +679,13 @@ class LatentDiffusion(DDPM):
|
||||
self.embedding_manager = self.instantiate_embedding_manager(
|
||||
personalization_config, self.cond_stage_model
|
||||
)
|
||||
self.textual_inversion_manager = TextualInversionManager(
|
||||
tokenizer = self.cond_stage_model.tokenizer,
|
||||
text_encoder = self.cond_stage_model.transformer,
|
||||
full_precision = True
|
||||
)
|
||||
# this circular component dependency is gross and bad, needs to be rethought
|
||||
self.cond_stage_model.set_textual_inversion_manager(self.textual_inversion_manager)
|
||||
|
||||
self.emb_ckpt_counter = 0
|
||||
|
||||
|
@ -209,12 +209,12 @@ class KSampler(Sampler):
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
|
||||
# setup attention maps saving. checks for None are because there are multiple code paths to get here.
|
||||
attention_maps_saver = None
|
||||
attention_map_saver = None
|
||||
if attention_maps_callback is not None and extra_conditioning_info is not None:
|
||||
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
attention_map_token_ids = range(1, eos_token_index)
|
||||
attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
|
||||
attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
@ -229,8 +229,8 @@ class KSampler(Sampler):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if attention_maps_saver is not None:
|
||||
attention_maps_callback(attention_maps_saver)
|
||||
if attention_map_saver is not None:
|
||||
attention_maps_callback(attention_map_saver)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
|
@ -1,14 +1,23 @@
|
||||
import traceback
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
|
||||
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
|
||||
CrossAttentionType
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ThresholdSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
'''
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
@ -18,6 +27,7 @@ class InvokeAIDiffuserComponent:
|
||||
* Cross attention control ("prompt2prompt")
|
||||
* Hybrid conditioning (used for inpainting)
|
||||
'''
|
||||
debug_thresholding = False
|
||||
|
||||
|
||||
class ExtraConditioningInfo:
|
||||
@ -36,6 +46,7 @@ class InvokeAIDiffuserComponent:
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
@ -77,7 +88,8 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioning: Union[torch.Tensor,dict],
|
||||
conditioning: Union[torch.Tensor,dict],
|
||||
unconditional_guidance_scale: float,
|
||||
step_index: Optional[int]=None
|
||||
step_index: Optional[int]=None,
|
||||
threshold: Optional[ThresholdSettings]=None,
|
||||
):
|
||||
"""
|
||||
:param x: current latents
|
||||
@ -86,6 +98,7 @@ class InvokeAIDiffuserComponent:
|
||||
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
||||
:param threshold: threshold to apply after each step
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
@ -106,13 +119,13 @@ class InvokeAIDiffuserComponent:
|
||||
else:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
|
||||
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
|
||||
|
||||
if threshold:
|
||||
combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma)
|
||||
|
||||
return combined_next_x
|
||||
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
@ -120,8 +133,11 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
|
||||
both_conditionings).chunk(2)
|
||||
both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
|
||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||
if conditioned_next_x.device.type == 'mps':
|
||||
# prevent a result filled with zeros. seems to be a torch bug.
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
@ -179,6 +195,51 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
return combined_next_x
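A tiny numeric illustration of the classifier-free guidance combination performed by _combine() (the values are made up):

import torch

uncond = torch.tensor([0.10, 0.20])   # unconditioned prediction
cond   = torch.tensor([0.30, 0.10])   # conditioned prediction
scale  = 7.5                          # unconditional_guidance_scale (CFG)

combined = uncond + (cond - uncond) * scale
print(combined)                       # tensor([ 1.6000, -0.5500])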
|
||||
|
||||
def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor:
|
||||
warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf
|
||||
if warmup_scale < 1:
|
||||
# This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395
|
||||
warming_threshold = 1 + (threshold - 1) * warmup_scale
|
||||
current_threshold = np.clip(warming_threshold, 1, threshold)
|
||||
else:
|
||||
current_threshold = threshold
|
||||
|
||||
if current_threshold <= 0:
|
||||
return latents
|
||||
maxval = latents.max().item()
|
||||
minval = latents.min().item()
|
||||
|
||||
scale = 0.7 # default value from #395
|
||||
|
||||
if self.debug_thresholding:
|
||||
std, mean = [i.item() for i in torch.std_mean(latents)]
|
||||
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
|
||||
print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
|
||||
|
||||
if maxval < current_threshold and minval > -current_threshold:
|
||||
return latents
|
||||
|
||||
if maxval > current_threshold:
|
||||
maxval = np.clip(maxval * scale, 1, current_threshold)
|
||||
|
||||
if minval < -current_threshold:
|
||||
minval = np.clip(minval * scale, -current_threshold, -1)
|
||||
|
||||
if self.debug_thresholding:
|
||||
outside = torch.count_nonzero((latents < minval) | (latents > maxval))
|
||||
print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values will be clamped")
|
||||
|
||||
return latents.clamp(minval, maxval)
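The warmup arithmetic above is easy to sanity-check in isolation. The following is a minimal standalone sketch; the sigma/threshold/warmup numbers are illustrative only, not values used anywhere in this diff:

import numpy as np

def effective_threshold(sigma: float, threshold: float, warmup: float) -> float:
    # mirrors the warmup logic in _threshold(): the scale grows as sigma falls from ~1000
    warmup_scale = (1 - sigma / 1000) / warmup if warmup else float('inf')
    if warmup_scale < 1:
        warming_threshold = 1 + (threshold - 1) * warmup_scale
        return float(np.clip(warming_threshold, 1, threshold))
    return threshold

# e.g. with threshold=3.0, warmup=0.5: sigma=900 -> 1.4, sigma=750 -> 2.0, sigma<=500 -> 3.0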
|
||||
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
|
@ -162,7 +162,6 @@ def get_mem_free_total(device):
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
|
@ -1,11 +1,12 @@
|
||||
import os.path
|
||||
from cmath import log
|
||||
import torch
|
||||
from attr import dataclass
|
||||
from torch import nn
|
||||
|
||||
import sys
|
||||
|
||||
from ldm.invoke.concepts_lib import Concepts
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.data.personalized import per_img_token_list
|
||||
from transformers import CLIPTokenizer
|
||||
from functools import partial
|
||||
@ -14,36 +15,16 @@ from picklescan.scanner import scan_file_path
|
||||
PROGRESSIVE_SCALE = 2000
|
||||
|
||||
|
||||
def get_clip_token_for_string(tokenizer, string):
|
||||
batch_encoding = tokenizer(
|
||||
string,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt',
|
||||
)
|
||||
tokens = batch_encoding['input_ids']
|
||||
""" assert (
|
||||
torch.count_nonzero(tokens - 49407) == 2
|
||||
), f"String '{string}' maps to more than a single token. Please use another string" """
|
||||
def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
|
||||
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||
return token_id
|
||||
|
||||
return tokens[0, 1]
|
||||
def get_embedding_for_clip_token_id(embedder, token_id):
|
||||
if type(token_id) is not torch.Tensor:
|
||||
token_id = torch.tensor(token_id, dtype=torch.int)
|
||||
return embedder(token_id.unsqueeze(0))[0, 0]
|
||||
|
||||
|
||||
def get_bert_token_for_string(tokenizer, string):
|
||||
token = tokenizer(string)
|
||||
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
|
||||
|
||||
token = token[0, 1]
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def get_embedding_for_clip_token(embedder, token):
|
||||
return embedder(token.unsqueeze(0))[0, 0]
|
||||
|
||||
class EmbeddingManager(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -58,8 +39,7 @@ class EmbeddingManager(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.embedder = embedder
|
||||
self.concepts_library=Concepts()
|
||||
self.concepts_loaded = dict()
|
||||
self.concepts_library=HuggingFaceConceptsLibrary()
|
||||
|
||||
self.string_to_token_dict = {}
|
||||
self.string_to_param_dict = nn.ParameterDict()
|
||||
@ -77,11 +57,11 @@ class EmbeddingManager(nn.Module):
|
||||
embedder, 'tokenizer'
|
||||
): # using Stable Diffusion's CLIP encoder
|
||||
self.is_clip = True
|
||||
get_token_for_string = partial(
|
||||
get_clip_token_for_string, embedder.tokenizer
|
||||
get_token_id_for_string = partial(
|
||||
get_clip_token_id_for_string, embedder.tokenizer
|
||||
)
|
||||
get_embedding_for_tkn = partial(
|
||||
get_embedding_for_clip_token,
|
||||
get_embedding_for_tkn_id = partial(
|
||||
get_embedding_for_clip_token_id,
|
||||
embedder.transformer.text_model.embeddings,
|
||||
)
|
||||
# per bug report #572
|
||||
@ -89,10 +69,10 @@ class EmbeddingManager(nn.Module):
|
||||
token_dim = 768
|
||||
else: # using LDM's BERT encoder
|
||||
self.is_clip = False
|
||||
get_token_for_string = partial(
|
||||
get_bert_token_for_string, embedder.tknz_fn
|
||||
get_token_id_for_string = partial(
|
||||
get_bert_token_id_for_string, embedder.tknz_fn
|
||||
)
|
||||
get_embedding_for_tkn = embedder.transformer.token_emb
|
||||
get_embedding_for_tkn_id = embedder.transformer.token_emb
|
||||
token_dim = 1280
|
||||
|
||||
if per_image_tokens:
|
||||
@ -100,15 +80,13 @@ class EmbeddingManager(nn.Module):
|
||||
|
||||
for idx, placeholder_string in enumerate(placeholder_strings):
|
||||
|
||||
token = get_token_for_string(placeholder_string)
|
||||
token_id = get_token_id_for_string(placeholder_string)
|
||||
|
||||
if initializer_words and idx < len(initializer_words):
|
||||
init_word_token = get_token_for_string(initializer_words[idx])
|
||||
init_word_token_id = get_token_id_for_string(initializer_words[idx])
|
||||
|
||||
with torch.no_grad():
|
||||
init_word_embedding = get_embedding_for_tkn(
|
||||
init_word_token.cpu()
|
||||
)
|
||||
init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
|
||||
|
||||
token_params = torch.nn.Parameter(
|
||||
init_word_embedding.unsqueeze(0).repeat(
|
||||
@ -132,7 +110,7 @@ class EmbeddingManager(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
self.string_to_token_dict[placeholder_string] = token
|
||||
self.string_to_token_dict[placeholder_string] = token_id
|
||||
self.string_to_param_dict[placeholder_string] = token_params
|
||||
|
||||
def forward(
|
||||
@ -140,6 +118,8 @@ class EmbeddingManager(nn.Module):
|
||||
tokenized_text,
|
||||
embedded_text,
|
||||
):
|
||||
# torch.save(embedded_text, '/tmp/embedding-manager-uglysonic-pre-rewrite.pt')
|
||||
|
||||
b, n, device = *tokenized_text.shape, tokenized_text.device
|
||||
|
||||
for (
|
||||
@ -164,7 +144,7 @@ class EmbeddingManager(nn.Module):
|
||||
)
|
||||
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token.to(tokenized_text.device)
|
||||
tokenized_text == placeholder_token
|
||||
)
|
||||
|
||||
if placeholder_rows.nelement() == 0:
|
||||
@ -182,9 +162,7 @@ class EmbeddingManager(nn.Module):
|
||||
new_token_row = torch.cat(
|
||||
[
|
||||
tokenized_text[row][:col],
|
||||
placeholder_token.repeat(num_vectors_for_token).to(
|
||||
device
|
||||
),
|
||||
torch.tensor([placeholder_token] * num_vectors_for_token, device=device),
|
||||
tokenized_text[row][col + 1 :],
|
||||
],
|
||||
axis=0,
|
||||
@ -212,22 +190,6 @@ class EmbeddingManager(nn.Module):
|
||||
ckpt_path,
|
||||
)
|
||||
|
||||
def load_concepts(self, concepts:list[str], full=True):
|
||||
bin_files = list()
|
||||
for concept_name in concepts:
|
||||
if concept_name in self.concepts_loaded:
|
||||
continue
|
||||
else:
|
||||
bin_file = self.concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
bin_files.append(bin_file)
|
||||
self.concepts_loaded[concept_name]=True
|
||||
self.load(bin_files, full)
|
||||
|
||||
def list_terms(self) -> list[str]:
|
||||
return self.concepts_loaded.keys()
|
||||
|
||||
def load(self, ckpt_paths, full=True):
|
||||
if len(ckpt_paths) == 0:
|
||||
return
|
||||
@ -282,14 +244,16 @@ class EmbeddingManager(nn.Module):
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
|
||||
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
self.embedder.transformer.resize_token_embeddings(new_token_count)
|
||||
existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||
if existing_token_id == self.embedder.tokenizer.unk_token_id:
|
||||
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
self.embedder.transformer.resize_token_embeddings(new_token_count)
|
||||
|
||||
token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
|
||||
self.string_to_token_dict[token_str] = token
|
||||
token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||
self.string_to_token_dict[token_str] = token_id
|
||||
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
|
||||
|
||||
def parse_embedding(self, embedding_file: str):
|
||||
@ -318,7 +282,7 @@ class EmbeddingManager(nn.Module):
|
||||
print('>> More than 1 embedding found. Will use the first one')
|
||||
|
||||
embedding = list(embedding_ckpt['string_to_param'].values())[0]
|
||||
except (AttributeError,KeyError):
|
||||
except (AttributeError,KeyError):
|
||||
return self.handle_broken_pt_variants(embedding_ckpt, embedding_file)
|
||||
|
||||
embedding_info['embedding'] = embedding
|
||||
|
@ -1,5 +1,7 @@
|
||||
import math
|
||||
import os.path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
@ -8,7 +10,8 @@ from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
#from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
from ldm.modules.x_transformer import (
|
||||
Encoder,
|
||||
@ -106,7 +109,7 @@ class BERTTokenizer(AbstractEncoder):
|
||||
BertTokenizerFast,
|
||||
)
|
||||
|
||||
cache = os.path.join(Globals.root,'models/bert-base-uncased')
|
||||
cache = global_cache_dir('hub')
|
||||
try:
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained(
|
||||
'bert-base-uncased',
|
||||
@ -235,26 +238,28 @@ class SpatialRescaler(nn.Module):
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
tokenizer: CLIPTokenizer
|
||||
transformer: CLIPTextModel
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
device=choose_torch_device(),
|
||||
max_length=77,
|
||||
version:str='openai/clip-vit-large-patch14',
|
||||
max_length:int=77,
|
||||
tokenizer:Optional[CLIPTokenizer]=None,
|
||||
transformer:Optional[CLIPTextModel]=None,
|
||||
):
|
||||
super().__init__()
|
||||
cache = os.path.join(Globals.root,'models',version)
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
cache = global_cache_dir('hub')
|
||||
self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
|
||||
version,
|
||||
cache_dir=cache,
|
||||
local_files_only=True
|
||||
)
|
||||
self.transformer = CLIPTextModel.from_pretrained(
|
||||
self.transformer = transformer or CLIPTextModel.from_pretrained(
|
||||
version,
|
||||
cache_dir=cache,
|
||||
local_files_only=True
|
||||
)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.freeze()
|
||||
|
||||
@ -460,12 +465,25 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.transformer.device
|
||||
|
||||
@device.setter
|
||||
def device(self, device):
|
||||
self.transformer.to(device=device)
|
||||
|
||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
fragment_weights_key = "fragment_weights"
|
||||
return_tokens_key = "return_tokens"
|
||||
|
||||
def set_textual_inversion_manager(self, manager): #TextualInversionManager):
|
||||
# TODO all of the weighting and expanding stuff needs be moved out of this class
|
||||
self.textual_inversion_manager = manager
|
||||
|
||||
def forward(self, text: list, **kwargs):
|
||||
# TODO all of the weighting and expanding stuff needs be moved out of this class
|
||||
'''
|
||||
|
||||
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
||||
@ -560,19 +578,43 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
else:
|
||||
return batch_z
|
||||
|
||||
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||
tokens = self.tokenizer(
|
||||
def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||
"""
|
||||
Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
|
||||
`[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
|
||||
`include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
|
||||
(typically 75 tokens + eos/bos markers).
|
||||
|
||||
:param fragments: The strings to convert.
|
||||
:param include_start_and_end_markers:
|
||||
:return:
|
||||
"""
|
||||
|
||||
# for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
|
||||
token_ids_list = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
return_tensors=None, # just give me lists of ints
|
||||
)['input_ids']
|
||||
if include_start_and_end_markers:
|
||||
return tokens
|
||||
else:
|
||||
return [x[1:-1] for x in tokens]
|
||||
|
||||
result = []
|
||||
for token_ids in token_ids_list:
|
||||
# trim eos/bos
|
||||
token_ids = token_ids[1:-1]
|
||||
# pad for textual inversions with vector length >1
|
||||
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
|
||||
# restrict length to max_length-2 (leaving room for bos/eos)
|
||||
token_ids = token_ids[0:self.max_length - 2]
|
||||
# add back eos/bos if requested
|
||||
if include_start_and_end_markers:
|
||||
token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
|
||||
|
||||
result.append(token_ids)
|
||||
|
||||
return result
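For illustration, the rewritten get_token_ids() behaves roughly as follows. The numeric token ids below are invented (real ids depend on the CLIP vocabulary), and `embedder` is assumed to be an already-constructed WeightedFrozenCLIPEmbedder with a textual inversion manager set:

# per-fragment token id lists, bos/eos stripped, textual-inversion pad ids expanded
token_id_lists = embedder.get_token_ids(["a cat", "sitting"], include_start_and_end_markers=False)
# -> e.g. [[320, 2368], [4919]]   (illustrative ids only)

# with markers, each list is wrapped as [bos, ..., eos] and capped at max_length
wrapped = embedder.get_token_ids(["a cat"], include_start_and_end_markers=True)
# -> e.g. [[49406, 320, 2368, 49407]]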
|
||||
|
||||
|
||||
@classmethod
|
||||
@ -597,56 +639,58 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
if len(fragments) == 0 and len(weights) == 0:
|
||||
fragments = ['']
|
||||
weights = [1]
|
||||
item_encodings = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=True,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
all_tokens = []
|
||||
per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
|
||||
all_token_ids = []
|
||||
per_token_weights = []
|
||||
#print("all fragments:", fragments, weights)
|
||||
for index, fragment in enumerate(item_encodings):
|
||||
weight = weights[index]
|
||||
for index, fragment in enumerate(per_fragment_token_ids):
|
||||
weight = float(weights[index])
|
||||
#print("processing fragment", fragment, weight)
|
||||
fragment_tokens = item_encodings[index]
|
||||
#print("fragment", fragment, "processed to", fragment_tokens)
|
||||
# trim bos and eos markers before appending
|
||||
all_tokens.extend(fragment_tokens[1:-1])
|
||||
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
|
||||
this_fragment_token_ids = per_fragment_token_ids[index]
|
||||
#print("fragment", fragment, "processed to", this_fragment_token_ids)
|
||||
# append
|
||||
all_token_ids += this_fragment_token_ids
|
||||
# fill out weights tensor with one float per token
|
||||
per_token_weights += [weight] * len(this_fragment_token_ids)
|
||||
|
||||
if (len(all_tokens) + 2) > self.max_length:
|
||||
excess_token_count = (len(all_tokens) + 2) - self.max_length
|
||||
# leave room for bos/eos
|
||||
if len(all_token_ids) > self.max_length - 2:
|
||||
excess_token_count = len(all_token_ids) - (self.max_length - 2)
|
||||
# TODO build nice description string of how the truncation was applied
|
||||
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
|
||||
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
|
||||
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||
all_tokens = all_tokens[:self.max_length - 2]
|
||||
per_token_weights = per_token_weights[:self.max_length - 2]
|
||||
all_token_ids = all_token_ids[0:self.max_length - 2]
per_token_weights = per_token_weights[0:self.max_length - 2]
|
||||
|
||||
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||
# (77 = self.max_length)
|
||||
pad_length = self.max_length - 1 - len(all_tokens)
|
||||
all_tokens.insert(0, self.tokenizer.bos_token_id)
|
||||
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
|
||||
per_token_weights.insert(0, 1)
|
||||
per_token_weights.extend([1] * pad_length)
|
||||
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
|
||||
per_token_weights = [1.0] + per_token_weights + [1.0]
|
||||
pad_length = self.max_length - len(all_token_ids)
|
||||
all_token_ids += [self.tokenizer.eos_token_id] * pad_length
|
||||
per_token_weights += [1.0] * pad_length
|
||||
|
||||
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
|
||||
all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
|
||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
||||
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
|
||||
return all_tokens_tensor, per_token_weights_tensor
|
||||
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
|
||||
return all_token_ids_tensor, per_token_weights_tensor
|
||||
|
||||
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||
'''
|
||||
Build a tensor representing the passed-in tokens, each of which has a weight.
|
||||
:param tokens: A tensor of shape (77) containing token ids (integers)
|
||||
:param token_ids: A tensor of shape (77) containing token ids (integers)
|
||||
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
||||
:param weight_delta_from_empty: If True (the default), multiply only each token's distance from an "empty" feature vector by its weight; if False, multiply the whole feature vector for each token
|
||||
:param kwargs: passed on to self.transformer()
|
||||
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
||||
'''
|
||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
|
||||
if token_ids.shape != torch.Size([self.max_length]):
|
||||
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
|
||||
|
||||
z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs)
|
||||
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||
|
||||
if weight_delta_from_empty:
|
||||
@ -660,7 +704,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
z_delta_from_empty = z - empty_z
|
||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||
|
||||
weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||
#weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
||||
|
||||
#print("using empty-delta method, first 5 rows:")
|
||||
|
236
ldm/modules/prompt_to_embeddings_converter.py
Normal file
@ -0,0 +1,236 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
|
||||
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
|
||||
class WeightedPromptFragmentsToEmbeddingsConverter():
|
||||
|
||||
def __init__(self,
|
||||
tokenizer: CLIPTokenizer, # converts strings to lists of int token ids
|
||||
text_encoder: CLIPTextModel, # convert a list of int token ids to a tensor of embeddings
|
||||
textual_inversion_manager: TextualInversionManager = None
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.textual_inversion_manager = textual_inversion_manager
|
||||
|
||||
@property
|
||||
def max_length(self):
|
||||
return self.tokenizer.model_max_length
|
||||
|
||||
def get_embeddings_for_weighted_prompt_fragments(self,
|
||||
text: list[str],
|
||||
fragment_weights: list[float],
|
||||
should_return_tokens: bool = False,
|
||||
device='cpu'
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
|
||||
:param text: A list of fragments of text to which different weights are to be applied.
|
||||
:param fragment_weights: A batch of lists of weights, one for each entry in `fragments`.
|
||||
:return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1
|
||||
and 1280 for SD2
|
||||
'''
|
||||
if len(text) != len(fragment_weights):
|
||||
raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(text)} != {len(fragment_weights)})")
|
||||
|
||||
batch_z = None
|
||||
batch_tokens = None
|
||||
for fragments, weights in zip(text, fragment_weights):
|
||||
|
||||
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
|
||||
# applying a multiplier to the CFG scale on a per-token basis).
|
||||
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
|
||||
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
|
||||
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
|
||||
# "red" is to tell SD that it should almost completely *ignore* redness).
|
||||
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
|
||||
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
|
||||
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
|
||||
|
||||
# handle weights >=1
|
||||
tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device)
|
||||
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights)
|
||||
|
||||
# this is our starting point
|
||||
embeddings = base_embedding.unsqueeze(0)
|
||||
per_embedding_weights = [1.0]
|
||||
|
||||
# now handle weights <1
|
||||
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
|
||||
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
|
||||
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
|
||||
# removed.
|
||||
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
|
||||
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
|
||||
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
|
||||
for index, fragment_weight in enumerate(weights):
|
||||
if fragment_weight < 1:
|
||||
fragments_without_this = fragments[:index] + fragments[index+1:]
|
||||
weights_without_this = weights[:index] + weights[index+1:]
|
||||
tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device)
|
||||
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights)
|
||||
|
||||
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
|
||||
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
|
||||
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
|
||||
# therefore:
|
||||
# fragment_weight = 1: we are at base_z => lerp weight 0
|
||||
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
|
||||
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
|
||||
# so let's use tan(), because:
|
||||
# tan is 0.0 at 0,
|
||||
# 1.0 at PI/4, and
|
||||
# inf at PI/2
|
||||
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
|
||||
epsilon = 1e-9
|
||||
fragment_weight = max(epsilon, fragment_weight) # inf is bad
|
||||
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
|
||||
# todo handle negative weight?
|
||||
|
||||
per_embedding_weights.append(embedding_lerp_weight)
|
||||
|
||||
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
|
||||
|
||||
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||
|
||||
# append to batch
|
||||
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||
|
||||
# should have shape (B, 77, 768)
|
||||
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
|
||||
|
||||
if should_return_tokens:
|
||||
return batch_z, batch_tokens
|
||||
else:
|
||||
return batch_z
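The tan() mapping described in the comments above can be checked numerically on its own. A minimal sketch with example fragment weights (not part of the converter class):

import math

def embedding_lerp_weight(fragment_weight: float, epsilon: float = 1e-9) -> float:
    # 1.0 -> 0 (stay at the base embedding), 0.5 -> 1 (half-way), -> 0 -> +inf (fully overridden)
    fragment_weight = max(epsilon, fragment_weight)
    return math.tan((1.0 - fragment_weight) * math.pi / 2)

# embedding_lerp_weight(1.0) == 0.0, embedding_lerp_weight(0.5) ~= 1.0, embedding_lerp_weight(0.1) ~= 6.3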
|
||||
|
||||
def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||
"""
|
||||
Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
|
||||
`[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
|
||||
`include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
|
||||
(typically 75 tokens + eos/bos markers).
|
||||
|
||||
:param fragments: The strings to convert.
|
||||
:param include_start_and_end_markers:
|
||||
:return:
|
||||
"""
|
||||
# for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
|
||||
token_ids_list = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me lists of ints
|
||||
)['input_ids']
|
||||
|
||||
result = []
|
||||
for token_ids in token_ids_list:
|
||||
# trim eos/bos
|
||||
token_ids = token_ids[1:-1]
|
||||
# pad for textual inversions with vector length >1
|
||||
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
|
||||
# restrict length to max_length-2 (leaving room for bos/eos)
|
||||
token_ids = token_ids[0:self.max_length - 2]
|
||||
# add back eos/bos if requested
|
||||
if include_start_and_end_markers:
|
||||
token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
|
||||
|
||||
result.append(token_ids)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
|
||||
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
|
||||
if normalize:
|
||||
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
|
||||
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
|
||||
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
|
||||
return torch.sum(embeddings * reshaped_weights, dim=1)
|
||||
# lerped embeddings has shape (77, 768)
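apply_embedding_weights() is a normalized weighted sum over the stacked embeddings; the shapes can be verified with a quick standalone check (random tensors, with the (77, 768) shape used for SD1):

import torch

embeddings = torch.randn(1, 2, 77, 768)       # base embedding plus one "fragment removed" embedding
weights = torch.tensor([1.0, 3.0])
weights = weights / weights.sum()              # normalize=True behaviour
blended = torch.sum(embeddings * weights.reshape(2, 1, 1), dim=1)
assert blended.shape == (1, 77, 768)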
|
||||
|
||||
|
||||
def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str) -> (torch.Tensor, torch.Tensor):
|
||||
'''
|
||||
Given a list of text fragments and corresponding weights: tokenize each fragment, concatenate the token
sequences, and return a single token sequence that starts with the bos marker, ends with the eos marker, and is
padded or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the
passed-in weights so that there is one weight per token.
|
||||
|
||||
:param fragments: Text fragments to tokenize and concatenate. May be empty.
|
||||
:param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practice with SD1.5,
values >1.6 tend to produce garbage output. Must have the same length as `fragments`.
|
||||
:return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`.
|
||||
'''
|
||||
if len(fragments) != len(weights):
|
||||
raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(fragments)} != {len(weights)})")
|
||||
|
||||
# empty is meaningful
|
||||
if len(fragments) == 0:
|
||||
fragments = ['']
|
||||
weights = [1.0]
|
||||
per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
|
||||
all_token_ids = []
|
||||
per_token_weights = []
|
||||
#print("all fragments:", fragments, weights)
|
||||
for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights):
|
||||
# append
|
||||
all_token_ids += this_fragment_token_ids
|
||||
# fill out weights tensor with one float per token
|
||||
per_token_weights += [float(weight)] * len(this_fragment_token_ids)
|
||||
|
||||
# leave room for bos/eos
|
||||
max_token_count_without_bos_eos_markers = self.max_length - 2
|
||||
if len(all_token_ids) > max_token_count_without_bos_eos_markers:
|
||||
excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
|
||||
# TODO build nice description string of how the truncation was applied
|
||||
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
|
||||
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
|
||||
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
|
||||
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
|
||||
|
||||
# pad out to a self.max_length-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||
# (typically self.max_length == 77)
|
||||
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
|
||||
per_token_weights = [1.0] + per_token_weights + [1.0]
|
||||
pad_length = self.max_length - len(all_token_ids)
|
||||
all_token_ids += [self.tokenizer.eos_token_id] * pad_length
|
||||
per_token_weights += [1.0] * pad_length
|
||||
|
||||
all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
|
||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32, device=device)
|
||||
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
|
||||
return all_token_ids_tensor, per_token_weights_tensor
|
||||
|
||||
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
|
||||
'''
|
||||
Build a tensor that embeds the passed-in token IDs and applies the given per-token weights
|
||||
:param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
|
||||
:param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
|
||||
:return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
|
||||
where `token_dim` is 768 for SD1 and 1280 for SD2.
|
||||
'''
|
||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||
if token_ids.shape != torch.Size([self.max_length]):
|
||||
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
|
||||
|
||||
z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0),
|
||||
return_dict=False)[0]
|
||||
empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
|
||||
[self.tokenizer.pad_token_id] * (self.max_length-2) +
|
||||
[self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0)
|
||||
empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||
z_delta_from_empty = z - empty_z
|
||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||
|
||||
return weighted_z
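The per-token weighting implemented above scales each token's deviation from an empty-prompt embedding rather than its raw feature vector. A shape-only sketch (assumes z and empty_z have already been produced by the text encoder):

import torch

def weight_tokens(z: torch.Tensor, empty_z: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
    # z, empty_z: (1, 77, token_dim); per_token_weights: (77,)
    w = per_token_weights.reshape(1, -1, 1)
    # a weight of 1.0 leaves a token untouched; 0.0 collapses it onto the empty-prompt embedding
    return empty_z + (z - empty_z) * w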
|
293
ldm/modules/textual_inversion_manager.py
Normal file
@ -0,0 +1,293 @@
|
||||
import os
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
trigger_string: str
|
||||
embedding: torch.Tensor
|
||||
trigger_token_id: Optional[int] = None
|
||||
pad_token_ids: Optional[list[int]] = None
|
||||
|
||||
@property
|
||||
def embedding_vector_length(self) -> int:
|
||||
return self.embedding.shape[0]
|
||||
|
||||
class TextualInversionManager():
|
||||
def __init__(self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
full_precision: bool=True):
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.full_precision = full_precision
|
||||
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||
default_textual_inversions: list[TextualInversion] = []
|
||||
self.textual_inversions = default_textual_inversions
|
||||
|
||||
def load_huggingface_concepts(self, concepts: list[str]):
|
||||
for concept_name in concepts:
|
||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||
continue
|
||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||
if self.has_textual_inversion_for_trigger_string(trigger):
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||
|
||||
def get_all_trigger_strings(self) -> list[str]:
|
||||
return [ti.trigger_string for ti in self.textual_inversions]
|
||||
|
||||
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
||||
try:
|
||||
scan_result = scan_file_path(ckpt_path)
|
||||
if scan_result.infected_files == 1:
|
||||
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
||||
print('### For your safety, InvokeAI will not load this embed.')
|
||||
return
|
||||
except Exception:
|
||||
print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
|
||||
return
|
||||
|
||||
embedding_info = self._parse_embedding(ckpt_path)
|
||||
if embedding_info:
|
||||
try:
|
||||
self._add_textual_inversion(embedding_info['name'],
|
||||
embedding_info['embedding'],
|
||||
defer_injecting_tokens=defer_injecting_tokens)
|
||||
except ValueError:
|
||||
print(f' | ignoring incompatible embedding {embedding_info["name"]}')
|
||||
else:
|
||||
print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
|
||||
|
||||
def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion:
|
||||
"""
|
||||
Add a textual inversion to be recognised.
|
||||
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||
:return: The TextualInversion object that was added.
|
||||
"""
|
||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
|
||||
return
|
||||
if not self.full_precision:
|
||||
embedding = embedding.half()
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
elif len(embedding.shape) > 2:
|
||||
raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.")
|
||||
|
||||
try:
|
||||
ti = TextualInversion(
|
||||
trigger_string=trigger_str,
|
||||
embedding=embedding
|
||||
)
|
||||
if not defer_injecting_tokens:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
self.textual_inversions.append(ti)
|
||||
return ti
|
||||
|
||||
except ValueError as e:
|
||||
if str(e).startswith('Warning'):
|
||||
print(f">> {str(e)}")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
|
||||
raise
|
||||
|
||||
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
|
||||
|
||||
if ti.trigger_token_id is not None:
|
||||
raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
|
||||
|
||||
print(f'DEBUG: Injecting token {ti.trigger_string}')
|
||||
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
|
||||
|
||||
if ti.embedding_vector_length > 1:
|
||||
# for embeddings with vector length > 1
|
||||
pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
|
||||
# todo: batched UI for faster loading when vector length >2
|
||||
pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
|
||||
for (i, pad_token_str) in enumerate(pad_token_strings)]
|
||||
else:
|
||||
pad_token_ids = []
|
||||
|
||||
ti.trigger_token_id = trigger_token_id
|
||||
ti.pad_token_ids = pad_token_ids
|
||||
return ti.trigger_token_id
|
||||
|
||||
|
||||
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
||||
try:
|
||||
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
||||
return ti is not None
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
|
||||
def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
|
||||
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
|
||||
|
||||
|
||||
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
||||
return next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
|
||||
|
||||
def create_deferred_token_ids_for_any_trigger_terms(self, prompt_string: str) -> list[int]:
|
||||
injected_token_ids = []
|
||||
for ti in self.textual_inversions:
|
||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||
if ti.embedding_vector_length > 1:
|
||||
print(f">> Preparing tokens for textual inversion {ti.trigger_string}...")
|
||||
try:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
except ValueError as e:
|
||||
print(f' | ignoring incompatible embedding trigger {ti.trigger_string}')
|
||||
continue
|
||||
injected_token_ids.append(ti.trigger_token_id)
|
||||
injected_token_ids.extend(ti.pad_token_ids)
|
||||
return injected_token_ids
|
||||
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
||||
|
||||
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||
long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
|
||||
"""
|
||||
if len(prompt_token_ids) == 0:
|
||||
return prompt_token_ids
|
||||
|
||||
if prompt_token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
||||
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
||||
textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions]
|
||||
prompt_token_ids = prompt_token_ids.copy()
|
||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||
if token_id in textual_inversion_trigger_token_ids:
|
||||
textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
|
||||
for pad_idx in range(0, textual_inversion.embedding_vector_length-1):
|
||||
prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx])
|
||||
|
||||
return prompt_token_ids
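A worked example of the expansion (all token ids below are invented for illustration): suppose a textual inversion with trigger_token_id=49500, embedding_vector_length=3 and pad_token_ids=[49501, 49502] has been loaded. Then:

manager.expand_textual_inversion_token_ids_if_necessary([320, 49500, 2368])
# -> [320, 49500, 49501, 49502, 2368]

The two pad ids are spliced in directly after the trigger id, so the downstream embedding lookup picks up all three vectors of the inversion.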
|
||||
|
||||
|
||||
def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int:
|
||||
if len(embedding.shape) != 1:
|
||||
raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2")
|
||||
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if existing_token_id == self.tokenizer.unk_token_id:
|
||||
num_tokens_added = self.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.text_encoder.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
# the following call is slow - todo make batched for better performance with vector length >1
|
||||
self.text_encoder.resize_token_embeddings(new_token_count)
|
||||
|
||||
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if token_id == self.tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
|
||||
if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape:
|
||||
raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.")
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||
|
||||
return token_id
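The same add-then-resize pattern can be reproduced with plain transformers objects. The sketch below is not the manager's own code, just an outline of the underlying API calls; the model name is the standard SD1 text encoder and the random embedding is a stand-in:

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')

tokenizer.add_tokens('<my-trigger>')                   # no-op if the token already exists
text_encoder.resize_token_embeddings(len(tokenizer))   # grow the embedding matrix to match
token_id = tokenizer.convert_tokens_to_ids('<my-trigger>')
with torch.no_grad():
    text_encoder.get_input_embeddings().weight.data[token_id] = torch.randn(768)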
|
||||
|
||||
def _parse_embedding(self, embedding_file: str):
|
||||
file_type = embedding_file.split('.')[-1]
|
||||
if file_type == 'pt':
|
||||
return self._parse_embedding_pt(embedding_file)
|
||||
elif file_type == 'bin':
|
||||
return self._parse_embedding_bin(embedding_file)
|
||||
else:
|
||||
print(f'>> Not a recognized embedding file: {embedding_file}')
|
||||
|
||||
def _parse_embedding_pt(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
|
||||
embedding_info = {}
|
||||
|
||||
# Check if valid embedding file
|
||||
if 'string_to_token' in embedding_ckpt and 'string_to_param' in embedding_ckpt:
|
||||
|
||||
# Catch variants that do not have the expected keys or values.
|
||||
try:
|
||||
embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
|
||||
|
||||
# Check num of embeddings and warn user only the first will be used
|
||||
embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
|
||||
if embedding_info['num_of_embeddings'] > 1:
|
||||
print('>> More than 1 embedding found. Will use the first one')
|
||||
|
||||
embedding = list(embedding_ckpt['string_to_param'].values())[0]
|
||||
except (AttributeError,KeyError):
|
||||
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
|
||||
|
||||
embedding_info['embedding'] = embedding
|
||||
embedding_info['num_vectors_per_token'] = embedding.size()[0]
|
||||
embedding_info['token_dim'] = embedding.size()[1]
|
||||
|
||||
try:
|
||||
embedding_info['trained_steps'] = embedding_ckpt['step']
|
||||
embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
|
||||
embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
|
||||
except AttributeError:
|
||||
print(">> No Training Details Found. Passing ...")
|
||||
|
||||
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
||||
# They are actually .bin files
|
||||
elif len(embedding_ckpt.keys())==1:
|
||||
print('>> Detected .bin file masquerading as .pt file')
|
||||
embedding_info = self._parse_embedding_bin(embedding_file)
|
||||
|
||||
else:
|
||||
print('>> Invalid embedding format')
|
||||
embedding_info = None
|
||||
|
||||
return embedding_info
|
||||
|
||||
def _parse_embedding_bin(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
|
||||
embedding_info = {}
|
||||
|
||||
if len(embedding_ckpt.keys()) == 0:
|
||||
print(">> Invalid concepts file")
|
||||
embedding_info = None
|
||||
else:
|
||||
for token in list(embedding_ckpt.keys()):
|
||||
embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
|
||||
embedding_info['embedding'] = embedding_ckpt[token]
|
||||
embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
|
||||
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
||||
|
||||
return embedding_info
|
||||
|
||||
def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict:
|
||||
'''
|
||||
This handles the broken .pt file variants. We only know of one at present.
|
||||
'''
|
||||
embedding_info = {}
|
||||
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
|
||||
print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
|
||||
for token in list(embedding_ckpt['string_to_token'].keys()):
|
||||
embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
|
||||
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
|
||||
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
|
||||
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
||||
else:
|
||||
print('>> Invalid embedding format')
|
||||
embedding_info = None
|
||||
|
||||
return embedding_info
|
52
ldm/util.py
@ -1,17 +1,18 @@
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from collections import abc
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
|
||||
import multiprocessing as mp
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
|
||||
from collections import abc
|
||||
from inspect import isfunction
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from urllib import request
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import traceback
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
@ -250,7 +251,7 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
|
||||
if not debug_status:
|
||||
return
|
||||
|
||||
image_copy = debug_image.copy()
|
||||
image_copy = debug_image.copy().convert("RGBA")
|
||||
ImageDraw.Draw(image_copy).text(
|
||||
(5, 5),
|
||||
debug_text,
|
||||
@ -262,3 +263,32 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
|
||||
|
||||
if debug_result:
|
||||
return image_copy
|
||||
|
||||
#-------------------------------------
|
||||
class ProgressBar():
|
||||
def __init__(self,model_name='file'):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
def __call__(self, block_num, block_size, total_size):
|
||||
if not self.pbar:
|
||||
self.pbar=tqdm(desc=self.name,
|
||||
initial=0,
|
||||
unit='iB',
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
total=total_size)
|
||||
self.pbar.update(block_size)
|
||||
|
||||
def download_with_progress_bar(url:str, dest:Path)->bool:
|
||||
try:
|
||||
if not os.path.exists(dest):
|
||||
os.makedirs((os.path.dirname(dest) or '.'), exist_ok=True)
|
||||
request.urlretrieve(url,dest,ProgressBar(os.path.basename(dest)))
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
except OSError:
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
@ -8,32 +8,43 @@
|
||||
#
|
||||
print('Loading Python libraries...\n')
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import io
|
||||
import re
|
||||
import warnings
|
||||
import shutil
|
||||
from urllib import request
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from huggingface_hub import HfFolder, hf_hub_url, login as hf_hub_login
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Dict, Union
|
||||
from urllib import request
|
||||
|
||||
import requests
|
||||
import transformers
|
||||
from diffusers import StableDiffusionPipeline, AutoencoderKL
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.devices import choose_precision, choose_torch_device
|
||||
from getpass_asterisk import getpass_asterisk
|
||||
from huggingface_hub import HfFolder, hf_hub_url, login as hf_hub_login, whoami as hf_whoami
|
||||
from huggingface_hub.utils._errors import RevisionNotFoundError
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
from ldm.invoke.readline import generic_completer
|
||||
|
||||
import traceback
|
||||
import requests
|
||||
import clip
|
||||
import transformers
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
import torch
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
try:
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
except ImportError:
|
||||
sys.path.append('.')
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
|
||||
#--------------------------globals-----------------------
|
||||
Model_dir = 'models'
|
||||
Weights_dir = 'ldm/stable-diffusion-v1/'
|
||||
@ -150,14 +161,15 @@ will be given the option to view and change your selections.
|
||||
'''
|
||||
)
|
||||
for ds in Datasets.keys():
|
||||
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
|
||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
|
||||
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
|
||||
recommended = Datasets[ds].get('recommended',False)
|
||||
r_str = '(recommended)' if recommended else ''
|
||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {r_str}')
|
||||
if yes_or_no(' Download?',default_yes=recommended):
|
||||
datasets[ds]=counter
|
||||
counter += 1
|
||||
else:
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds]['recommended']:
|
||||
if Datasets[ds].get('recommended',False):
|
||||
datasets[ds]=counter
|
||||
counter += 1
|
||||
|
||||
@ -181,7 +193,7 @@ will be given the option to view and change your selections.
|
||||
def recommended_datasets()->dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds]['recommended']:
|
||||
if Datasets[ds].get('recommended',False):
|
||||
datasets[ds]=True
|
||||
return datasets
|
||||
|
||||
@ -240,6 +252,7 @@ The license terms are located here:
|
||||
print("=" * shutil.get_terminal_size()[0])
|
||||
print('Authenticating to Huggingface')
|
||||
hf_envvars = [ "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN" ]
|
||||
token_found = False
|
||||
if not (access_token := HfFolder.get_token()):
|
||||
print(f"Huggingface token not found in cache.")
|
||||
|
||||
@ -257,17 +270,21 @@ The license terms are located here:
|
||||
print(f"Huggingface token found in cache.")
|
||||
try:
|
||||
HfLogin(access_token)
|
||||
token_found = True
|
||||
except ValueError:
|
||||
print(f"Login failed due to invalid token found in cache")
|
||||
|
||||
if not yes_to_all:
|
||||
print('''
|
||||
You may optionally enter your Huggingface token now. InvokeAI *will* work without it, but some functionality may be limited.
|
||||
See https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/#using-a-hugging-face-concept for more information.
|
||||
if not (yes_to_all or token_found):
|
||||
print(''' You may optionally enter your Huggingface token now. InvokeAI
|
||||
*will* work without it but you will not be able to automatically
|
||||
download some of the Hugging Face style concepts. See
|
||||
https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/#using-a-hugging-face-concept
|
||||
for more information.
|
||||
|
||||
Visit https://huggingface.co/settings/tokens to generate a token. (Sign up for an account if needed).
|
||||
|
||||
Paste the token below using Ctrl-Shift-V (macOS/Linux) or right-click (Windows), and/or 'Enter' to continue.
|
||||
Paste the token below using Ctrl-V on macOS/Linux, or Ctrl-Shift-V or right-click on Windows.
|
||||
Alternatively press 'Enter' to skip this step and continue.
|
||||
You may re-run the configuration script again in the future if you do not wish to set the token right now.
|
||||
''')
|
||||
again = True
|
||||
@ -313,34 +330,61 @@ def migrate_models_ckpt():
|
||||
os.replace(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
|
||||
|
||||
#---------------------------------------------
|
||||
def download_weight_datasets(models:dict, access_token:str):
|
||||
def download_weight_datasets(models:dict, access_token:str, precision:str='float32'):
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models.keys():
|
||||
repo_id = Datasets[mod]['repo_id']
|
||||
filename = Datasets[mod]['file']
|
||||
dest = os.path.join(Globals.root,Model_dir,Weights_dir)
|
||||
success = hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=dest,
|
||||
model_name=filename,
|
||||
access_token=access_token
|
||||
)
|
||||
if success:
|
||||
successful[mod] = True
|
||||
if len(successful) < len(models):
|
||||
print(f'\n\n** There were errors downloading one or more files. **')
|
||||
print('Press any key to try again. Type ^C to quit.\n')
|
||||
input()
|
||||
return None
|
||||
|
||||
keys = ', '.join(successful.keys())
|
||||
print(f'Successfully installed {keys}')
|
||||
print(f'{mod}...',file=sys.stderr,end='')
|
||||
successful[mod] = _download_repo_or_file(Datasets[mod], access_token, precision=precision)
|
||||
return successful
|
||||
|
||||
def _download_repo_or_file(mconfig:DictConfig, access_token:str, precision:str='float32')->Path:
|
||||
path = None
|
||||
if mconfig['format'] == 'ckpt':
|
||||
path = _download_ckpt_weights(mconfig, access_token)
|
||||
else:
|
||||
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
|
||||
if 'vae' in mconfig and 'repo_id' in mconfig['vae']:
|
||||
_download_diffusion_weights(mconfig['vae'], access_token, precision=precision)
|
||||
return path
|
||||
|
||||
def _download_ckpt_weights(mconfig:DictConfig, access_token:str)->Path:
|
||||
repo_id = mconfig['repo_id']
|
||||
filename = mconfig['file']
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
model_name=filename,
|
||||
access_token=access_token
|
||||
)
|
||||
|
||||
def _download_diffusion_weights(mconfig:DictConfig, access_token:str, precision:str='float32'):
|
||||
repo_id = mconfig['repo_id']
|
||||
model_class = StableDiffusionGeneratorPipeline if mconfig.get('format',None)=='diffusers' else AutoencoderKL
|
||||
extra_arg_list = [{'revision':'fp16'},{}] if precision=='float16' else [{}]
|
||||
path = None
|
||||
for extra_args in extra_arg_list:
|
||||
try:
|
||||
path = download_from_hf(
|
||||
model_class,
|
||||
repo_id,
|
||||
cache_subdir='diffusers',
|
||||
safety_checker=None,
|
||||
**extra_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith('fp16 is not a valid'):
|
||||
print(f'Could not fetch half-precision version of model {repo_id}; fetching full-precision instead')
|
||||
else:
|
||||
print(f'An unexpected error occurred while downloading the model: {e}')
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
|
||||
#---------------------------------------------
|
||||
def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
|
||||
model_dest = os.path.join(model_dir, model_name)
|
||||
def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->Path:
|
||||
model_dest = Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
@@ -359,7 +403,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
|
||||
|
||||
if resp.status_code==416: # "range not satisfiable", which means nothing to return
|
||||
print(f'* {model_name}: complete file found. Skipping.')
|
||||
return True
|
||||
return model_dest
|
||||
elif resp.status_code != 200:
|
||||
print(f'** An error occurred during downloading {model_name}: {resp.reason}')
|
||||
elif exist_size > 0:
|
||||
@@ -370,7 +414,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f'*** ERROR DOWNLOADING {model_name}: {resp.text}')
|
||||
return False
|
||||
return None
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
@@ -385,8 +429,22 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f'An error occurred while downloading {model_name}: {str(e)}')
|
||||
return False
|
||||
return True
|
||||
return None
|
||||
return model_dest
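Since hf_download_with_resume now reports a Path on success and None on failure (instead of True/False), callers can record the on-disk location directly. A hedged usage sketch; the repo and file names are placeholders, and access_token is assumed to come from the authentication step above:

weights_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
ckpt = hf_download_with_resume(
    repo_id='example-org/example-model',   # placeholder
    model_dir=weights_dir,
    model_name='example.ckpt',             # placeholder
    access_token=access_token,
)
if ckpt is None:
    print('** download failed; re-running will resume from the partial file')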
|
||||
|
||||
# -----------------------------------------------------------------------------------
|
||||
#---------------------------------------------
|
||||
def is_huggingface_authenticated():
|
||||
# huggingface_hub 0.10 API isn't great for this, it could be OSError, ValueError,
|
||||
# maybe other things, not all end-user-friendly.
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
response = hf_whoami()
|
||||
if response.get('id') is not None:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
#---------------------------------------------
|
||||
def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
|
||||
@@ -404,7 +462,6 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
|
||||
print(f'Error downloading {label} model')
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
#---------------------------------------------
|
||||
def update_config_file(successfully_downloaded:dict,opt:dict):
|
||||
config_file = opt.config_file or Default_config_file
|
||||
@@ -441,29 +498,27 @@ def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str
|
||||
default_selected = False
|
||||
|
||||
for model in successfully_downloaded:
|
||||
a = Datasets[model]['config'].split('/')
|
||||
if a[0] != 'VAE':
|
||||
continue
|
||||
vae_target = a[1] if len(a)>1 else 'default'
|
||||
vaes[vae_target] = Datasets[model]['file']
|
||||
|
||||
for model in successfully_downloaded:
|
||||
if Datasets[model]['config'].startswith('VAE'): # skip VAE entries
|
||||
continue
|
||||
stanza = conf[model] if model in conf else { }
|
||||
|
||||
stanza['description'] = Datasets[model]['description']
|
||||
stanza['weights'] = os.path.join(Model_dir,Weights_dir,Datasets[model]['file'])
|
||||
stanza['config'] = os.path.normpath(os.path.join(SD_Configs, Datasets[model]['config']))
|
||||
stanza['width'] = Datasets[model]['width']
|
||||
stanza['height'] = Datasets[model]['height']
|
||||
mod = Datasets[model]
|
||||
stanza['description'] = mod['description']
|
||||
stanza['repo_id'] = mod['repo_id']
|
||||
stanza['format'] = mod['format']
|
||||
# diffusers don't need width and height (probably .ckpt doesn't either)
|
||||
# so we no longer require these in INITIAL_MODELS.yaml
|
||||
if 'width' in mod:
|
||||
stanza['width'] = mod['width']
|
||||
if 'height' in mod:
|
||||
stanza['height'] = mod['height']
|
||||
if 'file' in mod:
|
||||
stanza['weights'] = os.path.relpath(successfully_downloaded[model], start=Globals.root)
|
||||
stanza['config'] = os.path.normpath(os.path.join(SD_Configs,mod['config']))
|
||||
if 'vae' in mod:
|
||||
if 'file' in mod['vae']:
|
||||
stanza['vae'] = os.path.normpath(os.path.join(Model_dir, Weights_dir,mod['vae']['file']))
|
||||
else:
|
||||
stanza['vae'] = mod['vae']
|
||||
stanza.pop('default',None) # this will be set later
|
||||
if vaes:
|
||||
for target in vaes:
|
||||
if re.search(target, model, flags=re.IGNORECASE):
|
||||
stanza['vae'] = os.path.normpath(os.path.join(Model_dir,Weights_dir,vaes[target]))
|
||||
else:
|
||||
stanza['vae'] = os.path.normpath(os.path.join(Model_dir,Weights_dir,vaes['default']))
|
||||
|
||||
# BUG - the first stanza is always the default. User should select.
|
||||
if not default_selected:
|
||||
stanza['default'] = True
|
||||
@@ -477,17 +532,20 @@ def download_bert():
|
||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
from transformers import BertTokenizerFast
|
||||
download_from_hf(BertTokenizerFast,'bert-base-uncased')
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
#---------------------------------------------
|
||||
def download_from_hf(model_class:object, model_name:str):
|
||||
def download_from_hf(model_class:object, model_name:str, cache_subdir:Path=Path('hub'), **kwargs):
|
||||
print('',file=sys.stderr) # to prevent tqdm from overwriting
|
||||
return model_class.from_pretrained(model_name,
|
||||
cache_dir=os.path.join(Globals.root,Model_dir,model_name),
|
||||
resume_download=True
|
||||
path = global_cache_dir(cache_subdir)
|
||||
model = model_class.from_pretrained(model_name,
|
||||
cache_dir=path,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
return path if model else None
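The new signature lets callers pick the cache subdirectory and pass extra keyword arguments straight through to from_pretrained(). A hedged sketch mirroring the diffusers call made earlier in this diff; the repo id is a placeholder:

path = download_from_hf(
    StableDiffusionGeneratorPipeline,
    'example-org/example-model',   # placeholder repo id
    cache_subdir='diffusers',
    safety_checker=None,
    revision='fp16',               # _download_diffusion_weights retries without this on OSError
)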
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clip():
|
||||
@@ -585,11 +643,13 @@ def download_safety_checker():
|
||||
#-------------------------------------
|
||||
def download_weights(opt:dict) -> Union[str, None]:
|
||||
|
||||
precision = 'float32' if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
|
||||
if opt.yes_to_all:
|
||||
models = recommended_datasets()
|
||||
access_token = authenticate(opt.yes_to_all)
|
||||
if len(models)>0:
|
||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
||||
successfully_downloaded = download_weight_datasets(models, access_token, precision=precision)
|
||||
update_config_file(successfully_downloaded,opt)
|
||||
return
|
||||
|
||||
@@ -607,11 +667,11 @@ def download_weights(opt:dict) -> Union[str, None]:
|
||||
else: # 'skip'
|
||||
return
|
||||
|
||||
|
||||
access_token = authenticate()
|
||||
HfFolder.save_token(access_token)
|
||||
|
||||
print('\n** DOWNLOADING WEIGHTS **')
|
||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
||||
successfully_downloaded = download_weight_datasets(models, access_token, precision=precision)
|
||||
|
||||
update_config_file(successfully_downloaded,opt)
|
||||
if len(successfully_downloaded) < len(models):
|
||||
@@ -738,6 +798,12 @@ def main():
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
help='skip downloading the large Stable Diffusion weight files')
|
||||
parser.add_argument('--full-precision',
|
||||
dest='full_precision',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=False,
|
||||
help='use 32-bit weights instead of faster 16-bit weights')
|
||||
parser.add_argument('--yes','-y',
|
||||
dest='yes_to_all',
|
||||
action='store_true',
|
||||
|
156
scripts/merge_fe.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import npyscreen
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
import argparse
|
||||
from ldm.invoke.globals import Globals, global_set_root
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
class FloatSlider(npyscreen.Slider):
|
||||
# this is supposed to adjust display precision, but doesn't
|
||||
def translate_value(self):
|
||||
stri = "%3.2f / %3.2f" %(self.value, self.out_of)
|
||||
l = (len(str(self.out_of)))*2+4
|
||||
stri = stri.rjust(l)
|
||||
return stri
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = FloatSlider
|
||||
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
interpolations = ['weighted_sum',
|
||||
'sigmoid',
|
||||
'inv_sigmoid',
|
||||
'add_difference']
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
self.model_names = self.get_model_names()
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
name="Select up to three models to merge",
|
||||
value=''
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='First Model:',
|
||||
values=self.model_names,
|
||||
value=0,
|
||||
max_height=len(self.model_names)+1
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Second Model:',
|
||||
values=self.model_names,
|
||||
value=1,
|
||||
max_height=len(self.model_names)+1
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0,'None')
|
||||
self.model3 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Third Model:',
|
||||
values=models_plus_none,
|
||||
value=0,
|
||||
max_height=len(self.model_names)+1,
|
||||
)
|
||||
|
||||
for m in [self.model1,self.model2,self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Merge Method:',
|
||||
values=self.interpolations,
|
||||
value=0,
|
||||
max_height=len(self.interpolations),
|
||||
)
|
||||
self.alpha = self.add_widget_intelligent(
|
||||
FloatTitleSlider,
|
||||
name='Weight (alpha) to assign to second and third models:',
|
||||
out_of=1,
|
||||
step=0.05,
|
||||
lowest=0,
|
||||
value=0.5,
|
||||
)
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name='Name for merged model',
|
||||
value='',
|
||||
)
|
||||
|
||||
def models_changed(self):
|
||||
models = self.model1.values
|
||||
selected_model1 = self.model1.value[0]
|
||||
selected_model2 = self.model2.value[0]
|
||||
selected_model3 = self.model3.value[0]
|
||||
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
|
||||
self.merged_model_name.value = merged_model_name
|
||||
|
||||
if selected_model3 > 0:
|
||||
self.merge_method.values=['add_difference']
|
||||
self.merged_model_name.value += f'+{models[selected_model3]}'
|
||||
else:
|
||||
self.merge_method.values=self.interpolations
|
||||
self.merge_method.value=0
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def ok_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def validate_field_values(self)->bool:
|
||||
bad_fields = []
|
||||
selected_models = set((self.model1.value[0],self.model2.value[0],self.model3.value[0]))
|
||||
if len(selected_models) < 3:
|
||||
bad_fields.append('Please select two or three DIFFERENT models to compare')
|
||||
if len(bad_fields) > 0:
|
||||
message = 'The following problems were detected and must be corrected:'
|
||||
for problem in bad_fields:
|
||||
message += f'\n* {problem}'
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self)->List[str]:
|
||||
conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
||||
model_names = [name for name in conf.keys() if conf[name].get('format',None)=='diffusers']
|
||||
return sorted(model_names)
|
||||
|
||||
class MyApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm('MAIN', mergeModelsForm, name='Merge Models Settings')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='InvokeAI model merging')
|
||||
parser.add_argument(
|
||||
'--root_dir','--root-dir',
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help='Path to the invokeai runtime directory',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
|
||||
myapplication = MyApplication()
|
||||
myapplication.run()
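After run() returns, the user's selections remain on the form's widgets. A hedged sketch (not part of the diff) of how they could be read back; any downstream merge call is out of scope here:

form = myapplication.main
first = form.model_names[form.model1.value[0]]
second = form.model_names[form.model2.value[0]]
alpha = float(form.alpha.value)
method = form.merge_method.values[form.merge_method.value[0]]
print(f'Would merge {first} and {second} using {method}, alpha={alpha}')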
|
11
scripts/textual_inversion.py
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2023, Lincoln Stein @lstein
|
||||
from ldm.invoke.globals import Globals, set_root
|
||||
from ldm.invoke.textual_inversion_training import parse_args, do_textual_inversion_training
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
set_root(args.root_dir or Globals.root)
|
||||
kwargs = vars(args)
|
||||
do_textual_inversion_training(**kwargs)
|
333
scripts/textual_inversion_fe.py
Executable file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import npyscreen
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
from ldm.invoke.globals import Globals, global_set_root
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
import argparse
|
||||
|
||||
TRAINING_DATA = 'training-data'
|
||||
TRAINING_DIR = 'text-inversion-training'
|
||||
CONF_FILE = 'preferences.conf'
|
||||
|
||||
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
resolutions = [512, 768, 1024]
|
||||
lr_schedulers = [
|
||||
"linear", "cosine", "cosine_with_restarts",
|
||||
"polynomial","constant", "constant_with_warmup"
|
||||
]
|
||||
precisions = ['no','fp16','bf16']
|
||||
learnable_properties = ['object','style']
|
||||
|
||||
def __init__(self, parentApp, name, saved_args=None):
|
||||
self.saved_args = saved_args or {}
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
self.model_names, default = self.get_model_names()
|
||||
default_initializer_token = '★'
|
||||
default_placeholder_token = ''
|
||||
saved_args = self.saved_args
|
||||
|
||||
try:
|
||||
default = self.model_names.index(saved_args['model'])
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
self.model = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Model Name:',
|
||||
values=self.model_names,
|
||||
value=default,
|
||||
max_height=len(self.model_names)+1
|
||||
)
|
||||
self.placeholder_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name='Trigger Term:',
|
||||
value='', # saved_args.get('placeholder_token',''), # to restore previous term
|
||||
)
|
||||
self.placeholder_token.when_value_edited = self.initializer_changed
|
||||
self.nextrely -= 1
|
||||
self.nextrelx += 30
|
||||
self.prompt_token = self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
name="Trigger term for use in prompt",
|
||||
value='',
|
||||
)
|
||||
self.nextrelx -= 30
|
||||
self.initializer_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name='Initializer:',
|
||||
value=saved_args.get('initializer_token',default_initializer_token),
|
||||
)
|
||||
self.resume_from_checkpoint = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Resume from last saved checkpoint",
|
||||
value=False,
|
||||
)
|
||||
self.learnable_property = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Learnable property:",
|
||||
values=self.learnable_properties,
|
||||
value=self.learnable_properties.index(saved_args.get('learnable_property','object')),
|
||||
max_height=4,
|
||||
)
|
||||
self.train_data_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilenameCombo,
|
||||
name='Data Training Directory:',
|
||||
select_dir=True,
|
||||
must_exist=True,
|
||||
value=saved_args.get('train_data_dir',Path(Globals.root) / TRAINING_DATA / default_placeholder_token)
|
||||
)
|
||||
self.output_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilenameCombo,
|
||||
name='Output Destination Directory:',
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
value=saved_args.get('output_dir',Path(Globals.root) / TRAINING_DIR / default_placeholder_token)
|
||||
)
|
||||
self.resolution = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Image resolution (pixels):',
|
||||
values = self.resolutions,
|
||||
value=self.resolutions.index(saved_args.get('resolution',512)),
|
||||
scroll_exit = True,
|
||||
max_height=4,
|
||||
)
|
||||
self.center_crop = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Center crop images before resizing to resolution",
|
||||
value=saved_args.get('center_crop',False)
|
||||
)
|
||||
self.mixed_precision = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Mixed Precision:',
|
||||
values=self.precisions,
|
||||
value=self.precisions.index(saved_args.get('mixed_precision','fp16')),
|
||||
max_height=4,
|
||||
)
|
||||
self.max_train_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name='Max Training Steps:',
|
||||
out_of=10000,
|
||||
step=500,
|
||||
lowest=1,
|
||||
value=saved_args.get('max_train_steps',3000)
|
||||
)
|
||||
self.train_batch_size = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name='Batch Size (reduce if you run out of memory):',
|
||||
out_of=50,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get('train_batch_size',8),
|
||||
)
|
||||
self.learning_rate = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Learning Rate:",
|
||||
value=str(saved_args.get('learning_rate','5.0e-04'),)
|
||||
)
|
||||
self.scale_lr = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Scale learning rate by number GPUs, steps and batch size",
|
||||
value=saved_args.get('scale_lr',True),
|
||||
)
|
||||
self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Use xformers acceleration",
|
||||
value=saved_args.get('enable_xformers_memory_efficient_attention',False),
|
||||
)
|
||||
self.lr_scheduler = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Learning rate scheduler:',
|
||||
values = self.lr_schedulers,
|
||||
max_height=7,
|
||||
scroll_exit = True,
|
||||
value=self.lr_schedulers.index(saved_args.get('lr_scheduler','constant')),
|
||||
)
|
||||
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name='Gradient Accumulation Steps:',
|
||||
out_of=10,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get('gradient_accumulation_steps',4)
|
||||
)
|
||||
self.lr_warmup_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name='Warmup Steps:',
|
||||
out_of=100,
|
||||
step=1,
|
||||
lowest=0,
|
||||
value=saved_args.get('lr_warmup_steps',0),
|
||||
)
|
||||
|
||||
def initializer_changed(self):
|
||||
placeholder = self.placeholder_token.value
|
||||
self.prompt_token.value = f'(Trigger by using <{placeholder}> in your prompts)'
|
||||
self.train_data_dir.value = Path(Globals.root) / TRAINING_DATA / placeholder
|
||||
self.output_dir.value = Path(Globals.root) / TRAINING_DIR / placeholder
|
||||
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.ti_arguments = self.marshall_arguments()
|
||||
npyscreen.notify('Launching textual inversion training. This will take a while...')
|
||||
# The module load takes a while, so we do it while the form and message are still up
|
||||
import ldm.invoke.textual_inversion_training
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def ok_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def validate_field_values(self)->bool:
|
||||
bad_fields = []
|
||||
if self.model.value is None:
|
||||
bad_fields.append('Model Name must correspond to a known model in models.yaml')
|
||||
if not re.match('^[a-zA-Z0-9.-]+$',self.placeholder_token.value):
|
||||
bad_fields.append('Trigger term must only contain alphanumeric characters, periods, and hyphens')
|
||||
if self.train_data_dir.value is None:
|
||||
bad_fields.append('Data Training Directory cannot be empty')
|
||||
if self.output_dir.value is None:
|
||||
bad_fields.append('The Output Destination Directory cannot be empty')
|
||||
if len(bad_fields) > 0:
|
||||
message = 'The following problems were detected and must be corrected:'
|
||||
for problem in bad_fields:
|
||||
message += f'\n* {problem}'
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self)->(List[str],int):
|
||||
conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
||||
model_names = list(conf.keys())
|
||||
defaults = [idx for idx in range(len(model_names)) if 'default' in conf[model_names[idx]]]
|
||||
return (model_names,defaults[0])
|
||||
|
||||
def marshall_arguments(self)->dict:
|
||||
args = dict()
|
||||
|
||||
# the choices
|
||||
args.update(
|
||||
model = self.model_names[self.model.value[0]],
|
||||
resolution = self.resolutions[self.resolution.value[0]],
|
||||
lr_scheduler = self.lr_schedulers[self.lr_scheduler.value[0]],
|
||||
mixed_precision = self.precisions[self.mixed_precision.value[0]],
|
||||
learnable_property = self.learnable_properties[self.learnable_property.value[0]],
|
||||
)
|
||||
|
||||
# all the strings and booleans
|
||||
for attr in ('initializer_token','placeholder_token','train_data_dir',
|
||||
'output_dir','scale_lr','center_crop','enable_xformers_memory_efficient_attention'):
|
||||
args[attr] = getattr(self,attr).value
|
||||
|
||||
# all the integers
|
||||
for attr in ('train_batch_size','gradient_accumulation_steps',
|
||||
'max_train_steps','lr_warmup_steps'):
|
||||
args[attr] = int(getattr(self,attr).value)
|
||||
|
||||
# the floats (just one)
|
||||
args.update(
|
||||
learning_rate = float(self.learning_rate.value)
|
||||
)
|
||||
|
||||
# a special case
|
||||
if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
|
||||
args['resume_from_checkpoint'] = 'latest'
|
||||
|
||||
return args
|
||||
|
||||
class MyApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, saved_args=None):
|
||||
super().__init__()
|
||||
self.ti_arguments=None
|
||||
self.saved_args=saved_args
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm('MAIN', textualInversionForm, name='Textual Inversion Settings', saved_args=self.saved_args)
|
||||
|
||||
def copy_to_embeddings_folder(args:dict):
|
||||
'''
|
||||
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||
delete the full model and checkpoints.
|
||||
'''
|
||||
source = Path(args['output_dir'],'learned_embeds.bin')
|
||||
dest_dir_name = args['placeholder_token'].strip('<>')
|
||||
destination = Path(Globals.root,'embeddings',dest_dir_name)
|
||||
os.makedirs(destination,exist_ok=True)
|
||||
print(f'>> Training completed. Copying learned_embeds.bin into {str(destination)}')
|
||||
shutil.copy(source,destination)
|
||||
if (input('Delete training logs and intermediate checkpoints? [y] ') or 'y').startswith(('y','Y')):
|
||||
shutil.rmtree(Path(args['output_dir']))
|
||||
else:
|
||||
print(f'>> Keeping {args["output_dir"]}')
|
||||
|
||||
def save_args(args:dict):
|
||||
'''
|
||||
Save the current argument values to an omegaconf file
|
||||
'''
|
||||
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
|
||||
conf = OmegaConf.create(args)
|
||||
OmegaConf.save(config=conf, f=conf_file)
|
||||
|
||||
def previous_args()->dict:
|
||||
'''
|
||||
Get the previous arguments used.
|
||||
'''
|
||||
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
|
||||
try:
|
||||
conf = OmegaConf.load(conf_file)
|
||||
conf['placeholder_token'] = conf['placeholder_token'].strip('<>')
|
||||
except:
|
||||
conf= None
|
||||
|
||||
return conf
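The two helpers above persist the form values so the next run can pre-fill them. A hedged round-trip sketch through preferences.conf; the field values are illustrative and the text-inversion-training directory is assumed to already exist:

save_args({'placeholder_token': '<my-token>', 'max_train_steps': 3000})   # writes TRAINING_DIR/CONF_FILE
restored = previous_args()                 # None if the file is missing or unreadable
print(restored['placeholder_token'])       # 'my-token' (angle brackets are stripped on load)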
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='InvokeAI textual inversion training')
|
||||
parser.add_argument(
|
||||
'--root_dir','--root-dir',
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help='Path to the invokeai runtime directory',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
|
||||
saved_args = previous_args()
|
||||
myapplication = MyApplication(saved_args=saved_args)
|
||||
myapplication.run()
|
||||
|
||||
from ldm.invoke.textual_inversion_training import do_textual_inversion_training
|
||||
if args := myapplication.ti_arguments:
|
||||
os.makedirs(args['output_dir'],exist_ok=True)
|
||||
|
||||
# Automatically add angle brackets around the trigger
|
||||
if not re.match('^<.+>$',args['placeholder_token']):
|
||||
args['placeholder_token'] = f"<{args['placeholder_token']}>"
|
||||
|
||||
args['only_save_embeds'] = True
|
||||
save_args(args)
|
||||
|
||||
try:
|
||||
do_textual_inversion_training(**args)
|
||||
copy_to_embeddings_folder(args)
|
||||
except Exception as e:
|
||||
print('** An exception occurred during training. The exception was:')
|
||||
print(str(e))
|
||||
print('** DETAILS:')
|
||||
print(traceback.format_exc())
|
12
setup.py
@@ -1,3 +1,4 @@
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from setuptools import setup, find_packages
|
||||
@@ -9,7 +10,13 @@ def list_files(directory):
|
||||
listing.append(pair)
|
||||
return listing
|
||||
|
||||
VERSION = '2.2.5'
|
||||
|
||||
def get_version()->str:
|
||||
from ldm.invoke import __version__ as version
|
||||
return version
|
||||
|
||||
# The canonical version number is stored in the file ldm/invoke/_version.py
|
||||
VERSION = get_version()
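The hunk above replaces the hard-coded VERSION with a value imported from the package. A hedged sketch of the single-source-of-truth pattern the comment describes; the actual contents of ldm/invoke/_version.py are not shown in this diff and are assumed here:

# ldm/invoke/_version.py  (assumed)
__version__ = '2.2.5'

# ldm/invoke/__init__.py  (assumed re-export, so setup.py can import it as above)
from ldm.invoke._version import __version__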
|
||||
DESCRIPTION = ('An implementation of Stable Diffusion which provides various new features'
|
||||
' and options to aid the image generation process')
|
||||
LONG_DESCRIPTION = ('This version of Stable Diffusion features a slick WebGUI, an'
|
||||
@@ -85,7 +92,8 @@ setup(
|
||||
'Topic :: Scientific/Engineering :: Image Processing',
|
||||
],
|
||||
scripts = ['scripts/invoke.py','scripts/configure_invokeai.py', 'scripts/sd-metadata.py',
|
||||
'scripts/preload_models.py', 'scripts/images2prompt.py','scripts/merge_embeddings.py'
|
||||
'scripts/preload_models.py', 'scripts/images2prompt.py','scripts/merge_embeddings.py',
|
||||
'scripts/textual_inversion_fe.py','scripts/textual_inversion.py'
|
||||
],
|
||||
data_files=FRONTEND_FILES,
|
||||
)
|
||||
|
14
tests/inpainting/coyote-inpainting.prompt
Normal file
@@ -0,0 +1,14 @@
|
||||
# 🌻 🌻 🌻 sunflowers 🌻 🌻 🌻
|
||||
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.2
|
||||
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.4
|
||||
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.6
|
||||
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.8
|
||||
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.99
|
||||
|
||||
# 🌹 🌹 🌹 roses 🌹 🌹 🌹
|
||||
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.2
|
||||
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.4
|
||||
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.6
|
||||
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.8
|
||||
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.99
|
||||
|
BIN
tests/inpainting/coyote-input.webp
Normal file
Binary file not shown. (36 KiB)
BIN
tests/inpainting/coyote-mask.webp
Normal file
Binary file not shown. (1.5 KiB)
30
tests/inpainting/original.json
Normal file
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"model": "stable diffusion",
|
||||
"model_id": null,
|
||||
"model_hash": "cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516",
|
||||
"app_id": "invoke-ai/InvokeAI",
|
||||
"app_version": "v2.2.3",
|
||||
"image": {
|
||||
"height": 512,
|
||||
"steps": 50,
|
||||
"facetool": "gfpgan",
|
||||
"facetool_strength": 0,
|
||||
"seed": 1948097268,
|
||||
"perlin": 0,
|
||||
"init_mask": null,
|
||||
"width": 512,
|
||||
"upscale": null,
|
||||
"cfg_scale": 7.5,
|
||||
"prompt": [
|
||||
{
|
||||
"prompt": "a coyote, deep palette knife oil painting, red aloe, plants, desert landscape, award winning",
|
||||
"weight": 1
|
||||
}
|
||||
],
|
||||
"threshold": 0,
|
||||
"postprocessing": null,
|
||||
"sampler": "k_lms",
|
||||
"variations": [],
|
||||
"type": "txt2img"
|
||||
}
|
||||
}
|
301
tests/test_textual_inversion.py
Normal file
@@ -0,0 +1,301 @@
|
||||
|
||||
import unittest
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
|
||||
KNOWN_WORDS = ['a', 'b', 'c']
|
||||
KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
|
||||
UNKNOWN_WORDS = ['d', 'e', 'f']
|
||||
|
||||
class DummyEmbeddingsList(list):
|
||||
def __getattr__(self, name):
|
||||
if name == 'num_embeddings':
|
||||
return len(self)
|
||||
elif name == 'weight':
|
||||
return self
|
||||
elif name == 'data':
|
||||
return self
|
||||
|
||||
def make_dummy_embedding():
|
||||
return torch.randn([768])
|
||||
|
||||
class DummyTransformer:
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.embeddings = DummyEmbeddingsList([make_dummy_embedding() for _ in range(len(KNOWN_WORDS))])
|
||||
|
||||
def resize_token_embeddings(self, new_size=None):
|
||||
if new_size is None:
|
||||
return self.embeddings
|
||||
else:
|
||||
while len(self.embeddings) > new_size:
|
||||
self.embeddings.pop(-1)
|
||||
while len(self.embeddings) < new_size:
|
||||
self.embeddings.append(make_dummy_embedding())
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
class DummyTokenizer():
|
||||
def __init__(self):
|
||||
self.tokens = KNOWN_WORDS.copy()
|
||||
self.bos_token_id = 49406 # these are what the real CLIPTokenizer has
|
||||
self.eos_token_id = 49407
|
||||
self.pad_token_id = 49407
|
||||
self.unk_token_id = 49407
|
||||
|
||||
def convert_tokens_to_ids(self, token_str):
|
||||
try:
|
||||
return self.tokens.index(token_str)
|
||||
except ValueError:
|
||||
return self.unk_token_id
|
||||
|
||||
def add_tokens(self, token_str):
|
||||
if token_str in self.tokens:
|
||||
return 0
|
||||
self.tokens.append(token_str)
|
||||
return 1
|
||||
|
||||
|
||||
class DummyClipEmbedder:
|
||||
def __init__(self):
|
||||
self.max_length = 77
|
||||
self.transformer = DummyTransformer()
|
||||
self.tokenizer = DummyTokenizer()
|
||||
self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32)
|
||||
|
||||
def position_embedding(self, indices: Union[list,torch.Tensor]):
|
||||
if type(indices) is list:
|
||||
indices = torch.tensor(indices, dtype=int)
|
||||
return torch.index_select(self.position_embeddings_tensor, 0, indices)
|
||||
|
||||
|
||||
def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool:
|
||||
return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices))
|
||||
|
||||
|
||||
def make_dummy_textual_inversion_manager():
|
||||
return TextualInversionManager(
|
||||
tokenizer=DummyTokenizer(),
|
||||
text_encoder=DummyTransformer()
|
||||
)
|
||||
|
||||
class TextualInversionManagerTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
def test_construction(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
|
||||
def test_add_embedding_for_known_token(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
test_embedding = torch.randn([1, 768])
|
||||
test_embedding_name = KNOWN_WORDS[0]
|
||||
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||
|
||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
|
||||
ti = tim._add_textual_inversion(test_embedding_name, test_embedding)
|
||||
self.assertEqual(ti.trigger_token_id, 0)
|
||||
|
||||
|
||||
# check adding 'test' did not create a new word
|
||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
self.assertEqual(pre_embeddings_count, embeddings_count)
|
||||
|
||||
# check it was added
|
||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
||||
self.assertIsNotNone(textual_inversion)
|
||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
|
||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
||||
self.assertEqual(textual_inversion.trigger_token_id, ti.trigger_token_id)
|
||||
|
||||
def test_add_embedding_for_unknown_token(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
test_embedding_1 = torch.randn([1, 768])
|
||||
test_embedding_name_1 = UNKNOWN_WORDS[0]
|
||||
|
||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
|
||||
added_token_id_1 = tim._add_textual_inversion(test_embedding_name_1, test_embedding_1).trigger_token_id
|
||||
# new token id should get added on the end
|
||||
self.assertEqual(added_token_id_1, len(KNOWN_WORDS))
|
||||
|
||||
# check adding did create a new word
|
||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||
|
||||
# check it was added
|
||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
|
||||
self.assertIsNotNone(textual_inversion)
|
||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
||||
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
|
||||
|
||||
# add another one
|
||||
test_embedding_2 = torch.randn([1, 768])
|
||||
test_embedding_name_2 = UNKNOWN_WORDS[1]
|
||||
|
||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
|
||||
added_token_id_2 = tim._add_textual_inversion(test_embedding_name_2, test_embedding_2).trigger_token_id
|
||||
self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1)
|
||||
|
||||
# check adding did create a new word
|
||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||
|
||||
# check it was added
|
||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2))
|
||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_2)
|
||||
self.assertIsNotNone(textual_inversion)
|
||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2))
|
||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2)
|
||||
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_2)
|
||||
|
||||
# check the old one is still there
|
||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
|
||||
self.assertIsNotNone(textual_inversion)
|
||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
||||
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
|
||||
|
||||
|
||||
def test_pad_raises_on_eos_bos(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
prompt_token_ids_with_eos_bos = [tim.tokenizer.bos_token_id] + \
|
||||
KNOWN_WORDS_TOKEN_IDS + \
|
||||
[tim.tokenizer.eos_token_id]
|
||||
with self.assertRaises(ValueError):
|
||||
tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_with_eos_bos)
|
||||
|
||||
def test_pad_tokens_list_vector_length_1(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
||||
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
test_embedding_1v = torch.randn([1, 768])
|
||||
test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
|
||||
test_embedding_1v_token_id = tim._add_textual_inversion(test_embedding_1v_token, test_embedding_1v).trigger_token_id
|
||||
self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
|
||||
|
||||
# at the end
|
||||
prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_append)
|
||||
self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids)
|
||||
|
||||
# at the start
|
||||
prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_prepend)
|
||||
self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids)
|
||||
|
||||
# in the middle
|
||||
prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_insert)
|
||||
self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids)
|
||||
|
||||
def test_pad_tokens_list_vector_length_2(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
||||
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
test_embedding_2v = torch.randn([2, 768])
|
||||
test_embedding_2v_token = "<inversion-trigger-vector-length-2>"
|
||||
test_embedding_2v_token_id = tim._add_textual_inversion(test_embedding_2v_token, test_embedding_2v).trigger_token_id
|
||||
test_embedding_2v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_2v_token_id).pad_token_ids
|
||||
self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS))
|
||||
|
||||
# at the end
|
||||
prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_append)
|
||||
self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
# at the start
|
||||
prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_prepend)
|
||||
self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids)
|
||||
self.assertEqual([test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
# in the middle
|
||||
prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_insert)
|
||||
self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
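A hedged illustration of the expansion these assertions describe: a multi-vector trigger token is followed by its pad token ids wherever it appears in the prompt. The ids below are invented for the example:

prompt_ids = [0, 1, 2]          # KNOWN_WORDS_TOKEN_IDS
trigger_id = 3                  # id assigned to a 2-vector trigger
pad_ids    = [4]                # its pad_token_ids
# inserting the trigger mid-prompt: [0, 1, trigger, 2] expands to [0, 1, trigger, *pads, 2]
expanded = prompt_ids[0:2] + [trigger_id] + pad_ids + prompt_ids[2:3]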
|
||||
|
||||
def test_pad_tokens_list_vector_length_8(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
||||
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
test_embedding_8v = torch.randn([8, 768])
|
||||
test_embedding_8v_token = "<inversion-trigger-vector-length-8>"
|
||||
test_embedding_8v_token_id = tim._add_textual_inversion(test_embedding_8v_token, test_embedding_8v).trigger_token_id
|
||||
test_embedding_8v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_8v_token_id).pad_token_ids
|
||||
self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
|
||||
|
||||
# at the end
|
||||
prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_append)
|
||||
self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
# at the start
|
||||
prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_prepend)
|
||||
self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
|
||||
self.assertEqual([test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
# in the middle
|
||||
prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_insert)
|
||||
self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
|
||||
|
||||
|
||||
def test_deferred_loading(self):
|
||||
tim = make_dummy_textual_inversion_manager()
|
||||
test_embedding = torch.randn([1, 768])
|
||||
test_embedding_name = UNKNOWN_WORDS[0]
|
||||
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||
|
||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
|
||||
ti = tim._add_textual_inversion(test_embedding_name, test_embedding, defer_injecting_tokens=True)
|
||||
self.assertIsNone(ti.trigger_token_id)
|
||||
|
||||
# check that a new word is not yet created
|
||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
self.assertEqual(pre_embeddings_count, embeddings_count)
|
||||
|
||||
# check it was added
|
||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
||||
self.assertIsNotNone(textual_inversion)
|
||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
|
||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
||||
self.assertIsNone(textual_inversion.trigger_token_id)
|
||||
|
||||
# check it lazy-loads
|
||||
prompt = " ".join([KNOWN_WORDS[0], UNKNOWN_WORDS[0], KNOWN_WORDS[1]])
|
||||
tim.create_deferred_token_ids_for_any_trigger_terms(prompt)
|
||||
|
||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||
|
||||
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
||||
self.assertEqual(textual_inversion.trigger_token_id, len(KNOWN_WORDS))
|