Compare commits


1 Commit

Author SHA1 Message Date
c953051eae Remove Training references from invoke script 2024-02-29 22:16:44 -05:00
153 changed files with 5522 additions and 4553 deletions

@@ -1,33 +1,33 @@
-name: install frontend dependencies
+name: Install frontend dependencies
 description: Installs frontend dependencies with pnpm, with caching
 runs:
   using: 'composite'
   steps:
-    - name: setup node 18
+    - name: Setup Node 18
       uses: actions/setup-node@v4
       with:
         node-version: '18'
-    - name: setup pnpm
+    - name: Setup pnpm
      uses: pnpm/action-setup@v2
       with:
         version: 8
         run_install: false
-    - name: get pnpm store directory
+    - name: Get pnpm store directory
       shell: bash
       run: |
         echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
-    - name: setup cache
-      uses: actions/cache@v4
+    - uses: actions/cache@v3
+      name: Setup pnpm cache
       with:
         path: ${{ env.STORE_PATH }}
         key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
         restore-keys: |
           ${{ runner.os }}-pnpm-store-
-    - name: install frontend dependencies
+    - name: Install frontend dependencies
       run: pnpm install --prefer-frozen-lockfile
       shell: bash
       working-directory: invokeai/frontend/web


@ -0,0 +1,11 @@
name: Install python dependencies
description: Install python dependencies with pip, with caching
runs:
using: 'composite'
steps:
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml

.github/pr_labels.yml

@@ -1,59 +1,59 @@
-root:
+Root:
   - changed-files:
       - any-glob-to-any-file: '*'
-python-deps:
+PythonDeps:
   - changed-files:
       - any-glob-to-any-file: 'pyproject.toml'
-python:
+Python:
   - changed-files:
       - all-globs-to-any-file:
           - 'invokeai/**'
           - '!invokeai/frontend/web/**'
-python-tests:
+PythonTests:
   - changed-files:
       - any-glob-to-any-file: 'tests/**'
-ci-cd:
+CICD:
   - changed-files:
       - any-glob-to-any-file: .github/**
-docker:
+Docker:
   - changed-files:
       - any-glob-to-any-file: docker/**
-installer:
+Installer:
   - changed-files:
       - any-glob-to-any-file: installer/**
-docs:
+Documentation:
   - changed-files:
       - any-glob-to-any-file: docs/**
-invocations:
+Invocations:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/app/invocations/**'
-backend:
+Backend:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/backend/**'
-api:
+Api:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/app/api/**'
-services:
+Services:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/app/services/**'
-frontend-deps:
+FrontendDeps:
   - changed-files:
       - any-glob-to-any-file:
           - '**/*/package.json'
           - '**/*/pnpm-lock.yaml'
-frontend:
+Frontend:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/frontend/web/**'


@ -1,45 +0,0 @@
# Builds and uploads the installer and python build artifacts.
name: build installer
on:
workflow_dispatch:
workflow_call:
jobs:
build-installer:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <2 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install pypa/build
run: pip install --upgrade build
- name: setup frontend
uses: ./.github/actions/install-frontend-deps
- name: create installer
id: create_installer
run: ./create_installer.sh
working-directory: installer
- name: upload python distribution artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: ${{ steps.create_installer.outputs.DIST_PATH }}
- name: upload installer artifact
uses: actions/upload-artifact@v4
with:
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}

.github/workflows/check-frontend.yml

@ -0,0 +1,43 @@
# This workflow runs the frontend code quality checks.
#
# It may be triggered via dispatch, or by another workflow.
name: 'Check: frontend'
on:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
check-frontend:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: Set up frontend
uses: ./.github/actions/install-frontend-deps
- name: Run tsc check
run: 'pnpm run lint:tsc'
shell: bash
- name: Run dpdm check
run: 'pnpm run lint:dpdm'
shell: bash
- name: Run eslint check
run: 'pnpm run lint:eslint'
shell: bash
- name: Run prettier check
run: 'pnpm run lint:prettier'
shell: bash
- name: Run knip check
run: 'pnpm run lint:knip'
shell: bash

.github/workflows/check-pytest.yml

@ -0,0 +1,72 @@
# This workflow runs pytest on the codebase in a matrix of platforms.
#
# It may be triggered via dispatch, or by another workflow.
name: 'Check: pytest'
on:
workflow_dispatch:
workflow_call:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
strategy:
matrix:
python-version:
- '3.10'
pytorch:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
timeout-minutes: 30 # expected run time: <10 min, depending on platform
env:
PIP_USE_PEP517: '1'
steps:
- uses: actions/checkout@v4
- name: set test prompt to main branch validation
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install invokeai
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install
--editable=".[test]"
- name: run pytest
id: run-pytest
run: pytest

.github/workflows/check-python.yml

@ -0,0 +1,33 @@
# This workflow runs the python code quality checks.
#
# It may be triggered via dispatch, or by another workflow.
#
# TODO: Add mypy or pyright to the checks.
name: 'Check: python'
on:
workflow_dispatch:
workflow_call:
jobs:
check-backend:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- uses: actions/checkout@v4
- name: Install python dependencies
uses: ./.github/actions/install-python-deps
- name: Install ruff
run: pip install ruff
shell: bash
- name: Ruff check
run: ruff check --output-format=github .
shell: bash
- name: Ruff format
run: ruff format --check .
shell: bash


@ -1,68 +0,0 @@
# Runs frontend code quality checks.
#
# Checks for changes to frontend files before running the checks.
# When manually triggered or when called from another workflow, always runs the checks.
name: 'frontend checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-checks:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: tsc
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:tsc'
shell: bash
- name: dpdm
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:dpdm'
shell: bash
- name: eslint
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:eslint'
shell: bash
- name: prettier
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:prettier'
shell: bash
- name: knip
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:knip'
shell: bash


@ -1,48 +0,0 @@
# Runs frontend tests.
#
# Checks for changes to frontend files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'frontend tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-tests:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: vitest
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm test:no-watch'
shell: bash


@@ -1,4 +1,4 @@
-name: 'label PRs'
+name: "Pull Request Labeler"
 on:
   - pull_request_target
@@ -9,10 +9,8 @@ jobs:
       pull-requests: write
     runs-on: ubuntu-latest
     steps:
-      - name: checkout
+      - name: Checkout
         uses: actions/checkout@v4
-      - name: label PRs
-        uses: actions/labeler@v5
+      - uses: actions/labeler@v5
         with:
           configuration-path: .github/pr_labels.yml


@@ -21,29 +21,18 @@ jobs:
       SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
     steps:
-      - name: checkout
-        uses: actions/checkout@v4
-      - name: setup python
-        uses: actions/setup-python@v5
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
           python-version: '3.10'
           cache: pip
           cache-dependency-path: pyproject.toml
-      - name: set cache id
-        run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
-      - name: use cache
-        uses: actions/cache@v4
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - uses: actions/cache@v4
         with:
           key: mkdocs-material-${{ env.cache_id }}
           path: .cache
           restore-keys: |
             mkdocs-material-
-      - name: install dependencies
-        run: python -m pip install ".[docs]"
-      - name: build & deploy
-        run: mkdocs gh-deploy --force
+      - run: python -m pip install ".[docs]"
+      - run: mkdocs gh-deploy --force


@ -0,0 +1,39 @@
# This workflow runs `check-frontend.yml` on push or pull request.
#
# The actual checks are in a separate workflow to support simpler workflow
# composition without awkward or complicated conditionals.
name: 'On change: run check-frontend'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
jobs:
check-changed-frontend-files:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
outputs:
frontend_any_changed: ${{ steps.changed-files.outputs.frontend_any_changed }}
steps:
- uses: actions/checkout@v4
- name: Check for changed frontend files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
run-check-frontend:
needs: check-changed-frontend-files
if: ${{ needs.check-changed-frontend-files.outputs.frontend_any_changed == 'true' }}
uses: ./.github/workflows/check-frontend.yml


@ -0,0 +1,42 @@
# This workflow runs `check-python.yml` on push or pull request.
#
# The actual checks are in a separate workflow to support simpler workflow
# composition without awkward or complicated conditionals.
name: 'On change: run check-python'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
jobs:
check-changed-python-files:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
outputs:
python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }}
steps:
- uses: actions/checkout@v4
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
run-check-python:
needs: check-changed-python-files
if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }}
uses: ./.github/workflows/check-python.yml

.github/workflows/on-change-pytest.yml

@ -0,0 +1,42 @@
# This workflow runs `check-pytest.yml` on push or pull request.
#
# The actual checks are in a separate workflow to support simpler workflow
# composition without awkward or complicated conditionals.
name: 'On change: run pytest'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
jobs:
check-changed-python-files:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
outputs:
python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }}
steps:
- uses: actions/checkout@v4
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
run-pytest:
needs: check-changed-python-files
if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }}
uses: ./.github/workflows/check-pytest.yml


@ -1,64 +0,0 @@
# Runs python code quality checks.
#
# Checks for changes to python files before running the checks.
# When manually triggered or called from another workflow, always runs the tests.
#
# TODO: Add mypy or pyright to the checks.
name: 'python checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
jobs:
python-checks:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pip install ruff
shell: bash
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff format --check .
shell: bash


@ -1,94 +0,0 @@
# Runs python tests on a matrix of python versions and platforms.
#
# Checks for changes to python files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'python tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
runs-on: ${{ matrix.os }}
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pytest


@ -1,96 +1,103 @@
# Main release workflow. Triggered on tag push or manual trigger. name: Release
#
# - Runs all code checks and tests
# - Verifies the app version matches the tag version.
# - Builds the installer and build, uploading them as artifacts.
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
#
# See docs/RELEASE.md for more information on the release process.
name: release
on: on:
push: push:
tags: tags:
- 'v*' - 'v*'
workflow_dispatch: workflow_dispatch:
inputs:
skip_code_checks:
description: 'Skip code checks'
required: true
default: true
type: boolean
jobs: jobs:
check-version: check-version:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: checkout - uses: actions/checkout@v4
uses: actions/checkout@v4
- name: check python version - uses: samuelcolvin/check-python-version@v4
uses: samuelcolvin/check-python-version@v4
id: check-python-version id: check-python-version
with: with:
version_file_path: invokeai/version/invokeai_version.py version_file_path: invokeai/version/invokeai_version.py
frontend-checks: check-frontend:
uses: ./.github/workflows/frontend-checks.yml if: github.event.inputs.skip_code_checks != 'true'
uses: ./.github/workflows/check-frontend.yml
frontend-tests: check-python:
uses: ./.github/workflows/frontend-tests.yml if: github.event.inputs.skip_code_checks != 'true'
uses: ./.github/workflows/check-python.yml
python-checks: check-pytest:
uses: ./.github/workflows/python-checks.yml if: github.event.inputs.skip_code_checks != 'true'
uses: ./.github/workflows/check-pytest.yml
python-tests:
uses: ./.github/workflows/python-tests.yml
build: build:
uses: ./.github/workflows/build-installer.yml runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install python dependencies
uses: ./.github/actions/install-python-deps
- name: Install pypa/build
run: pip install --upgrade build
- name: Setup frontend
uses: ./.github/actions/install-frontend-deps
- name: Run create_installer.sh
id: create_installer
run: ./create_installer.sh --skip_frontend_checks
working-directory: installer
- name: Upload python distribution artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: ${{ steps.create_installer.outputs.DIST_PATH }}
- name: Upload installer artifact
uses: actions/upload-artifact@v4
with:
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
publish-testpypi: publish-testpypi:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min needs: [check-version, check-frontend, check-python, check-pytest, build]
needs: if: github.event_name != 'workflow_dispatch'
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment: environment:
name: testpypi name: testpypi
url: https://test.pypi.org/p/invokeai url: https://test.pypi.org/p/invokeai
steps: steps:
- name: download distribution from build job - name: Download distribution from build job
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
name: dist name: dist
path: dist/ path: dist/
- name: publish distribution to TestPyPI - name: Publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
repository-url: https://test.pypi.org/legacy/ repository-url: https://test.pypi.org/legacy/
publish-pypi: publish-pypi:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min needs: [check-version, check-frontend, check-python, check-pytest, build]
needs: if: github.event_name != 'workflow_dispatch'
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment: environment:
name: pypi name: pypi
url: https://pypi.org/p/invokeai url: https://pypi.org/p/invokeai
steps: steps:
- name: download distribution from build job - name: Download distribution from build job
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
name: dist name: dist
path: dist/ path: dist/
- name: publish distribution to PyPI - name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1


@@ -23,13 +23,13 @@ It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if
 Run `make tag-release` to tag the current commit and kick off the workflow.
-The release may also be dispatched [manually].
+The release may also be run [manually].
 ### Workflow Jobs and Process
 The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
-The publish jobs require manual approval and are only run if the other jobs succeed.
+The publish jobs run if the 5 concurrent jobs all succeed and if/when the publish jobs are approved.
 #### `check-version` Job
@@ -43,16 +43,17 @@ This job uses [samuelcolvin/check-python-version].
 #### Check and Test Jobs
-- **`python-tests`**: runs `pytest` on matrix of platforms
-- **`python-checks`**: runs `ruff` (format and lint)
-- **`frontend-tests`**: runs `vitest`
-- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
+This is our test suite.
+- **`check-pytest`**: runs `pytest` on matrix of platforms
+- **`check-python`**: runs `ruff` (format and lint)
+- **`check-frontend`**: runs `prettier` (format), `eslint` (lint), `madge` (circular refs) and `tsc` (static type check)
 > **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
 > **TODO** We should add an end-to-end test job that generates an image.
-#### `build-installer` Job
+#### `build` Job
 This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
@@ -61,7 +62,7 @@ This sets up both python and frontend dependencies and builds the python package
 #### Sanity Check & Smoke Test
-At this point, the release workflow pauses as the remaining publish jobs require approval.
+At this point, the release workflow pauses (the remaining jobs all require approval).
 A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
@@ -69,7 +70,7 @@ A maintainer should go to the **Summary** tab of the workflow, download the inst
 #### PyPI Publish Jobs
-The publish jobs will run if any of the previous jobs fail.
+The publish jobs will skip if any of the previous jobs skip or fail.
 They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
@@ -118,17 +119,13 @@ Once the release is published to PyPI, it's time to publish the GitHub release.
 > **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
-## Manual Build
+## Manually Running the Release Workflow
-The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
+The release workflow can be run manually. This is useful to get an installer build and test it out without needing to push a tag.
-No checks are run, it just builds.
+When run this way, you'll see **Skip code checks** checkbox. This allows the workflow to run without the time-consuming 3 code quality check jobs.
-## Manual Release
-The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
-This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
+The publish jobs will skip if the workflow was run manually.
 [InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
 [PyPI]: https://pypi.org/
@@ -139,4 +136,4 @@ This functionality is available as a fallback in case something goes wonky. Typi
 [GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
 [trusted publishers]: https://docs.pypi.org/trusted-publishers/
 [samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
-[manually]: #manual-release
+[manually]: #manually-running-the-release-workflow


@ -32,6 +32,7 @@ model. These are the:
Responsible for loading a model from disk Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference. into RAM and VRAM and getting it ready for inference.
## Location of the Code ## Location of the Code
The four main services can be found in The four main services can be found in
@@ -66,17 +67,19 @@ provides the following fields:
 | `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
 | `base_model` | BaseModelType | The base model that the model is compatible with |
 | `path` | str | Location of model on disk |
-| `hash` | str | Hash of the model |
+| `original_hash` | str | Hash of the model when it was first installed |
+| `current_hash` | str | Most recent hash of the model's contents |
 | `description` | str | Human-readable description of the model (optional) |
 | `source` | str | Model's source URL or repo id (optional) |
 The `key` is a unique 32-character random ID which was generated at
-install time. The `hash` field stores a hash of the model's
+install time. The `original_hash` field stores a hash of the model's
 contents at install time obtained by sampling several parts of the
 model's files using the `imohash` library. Over the course of the
 model's lifetime it may be transformed in various ways, such as
 changing its precision or converting it from a .safetensors to a
-diffusers model.
+diffusers model. When this happens, `original_hash` is unchanged, but
+`current_hash` is updated to indicate the current contents.
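As an illustration of the hashing described above, here is a minimal sketch that assumes the `imohash` Python package and its `hashfile()` helper; the model path is illustrative, and hashing a multi-file diffusers folder would additionally require combining the per-file hashes, which is not shown.

```
from imohash import hashfile  # fast hash that samples portions of a file rather than reading it all

model_path = "models/sd-1/main/my-model.safetensors"  # illustrative path
print(hashfile(model_path, hexdigest=True))  # hex digest of the kind stored in original_hash / current_hash
```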
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that `ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also are defined in `invokeai.backend.model_manager.config`. They are also
@ -91,6 +94,7 @@ The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file. `invokeai.yaml` file.
### CheckpointConfig ### CheckpointConfig
This adds support for checkpoint configurations, and adds the This adds support for checkpoint configurations, and adds the
@ -224,9 +228,9 @@ The way it works is as follows:
1. Retrieve the value of the `model_config_db` option from the user's 1. Retrieve the value of the `model_config_db` option from the user's
`invokeai.yaml` config file. `invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then: 2. If `model_config_db` is `auto` (the default), then:
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object - Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
opened on the passed connection and lock. opened on the passed connection and lock.
* Open up a new connection to `databases/invokeai.db` if `conn` - Open up a new connection to `databases/invokeai.db` if `conn`
and/or `lock` are missing (see note below). and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file` 3. If `model_config_db` is a Path, then use `from_db_file`
to return the appropriate type of ModelRecordService. to return the appropriate type of ModelRecordService.
@ -251,7 +255,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
Configurations can be retrieved in several ways. Configurations can be retrieved in several ways.
#### get_model(key) -> AnyModelConfig #### get_model(key) -> AnyModelConfig:
The basic functionality is to call the record store object's The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns `get_model()` method with the desired model's unique key. It returns
@ -268,28 +272,28 @@ print(model_conf.path)
If the key is unrecognized, this call raises an If the key is unrecognized, this call raises an
`UnknownModelException`. `UnknownModelException`.
#### exists(key) -> AnyModelConfig #### exists(key) -> AnyModelConfig:
Returns True if a model with the given key exists in the database.
#### search_by_path(path) -> AnyModelConfig #### search_by_path(path) -> AnyModelConfig:
Returns the configuration of the model whose path is `path`. The path Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links). models referred to by different paths (e.g. using symbolic links).
#### search_by_name(name, base, type) -> List[AnyModelConfig] #### search_by_name(name, base, type) -> List[AnyModelConfig]:
This method searches for models that match some combination of `name`, This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return `BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database. all the models in the database.
#### all_models() -> List[AnyModelConfig] #### all_models() -> List[AnyModelConfig]:
Return all the model configs in the database. Exactly equivalent to Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments. calling `search_by_name()` with no arguments.
#### search_by_tag(tags) -> List[AnyModelConfig] #### search_by_tag(tags) -> List[AnyModelConfig]:
`tags` is a list of strings. This method returns a list of model `tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples: configs that contain all of the given tags. Examples:
@ -308,11 +312,11 @@ commercializable_models = [x for x in store.all_models() \
if x.license.contains('allowCommercialUse=Sell')] if x.license.contains('allowCommercialUse=Sell')]
``` ```
#### version() -> str #### version() -> str:
Returns the version of the database, currently at `3.2` Returns the version of the database, currently at `3.2`
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase #### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
This method exists to ease the transition from the previous version of This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments the model manager, in which `get_model()` took the three arguments
@ -333,7 +337,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config Several methods allow you to create and update stored model config
records. records.
#### add_model(key, config) -> AnyModelConfig #### add_model(key, config) -> AnyModelConfig:
Given a key and a configuration, this will add the model's Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of configuration record to the database. `config` can either be a subclass of
@ -348,7 +352,7 @@ model with the same key is already in the database, or an
`InvalidModelConfigException` if a dict was passed and Pydantic `InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error. experienced a parse or validation error.
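A hedged sketch of `add_model()` as described above, reusing the `store` object opened earlier in this document; the key and all field values are illustrative, and the exact set of required fields follows the config table near the top of this document.

```
config = {
    "path": "sd-1/main/my-model",   # relative paths are resolved against the models_dir setting
    "name": "my-model",             # illustrative values throughout
    "base_model": "sd-1",
    "model_type": "main",
    "model_format": "diffusers",
    "description": "An example model record",
}

# Raises DuplicateModelException if a model with this key already exists, or
# InvalidModelConfigException if the dict fails Pydantic validation.
model_conf = store.add_model("0123456789abcdef0123456789abcdef", config)
```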
### update_model(key, config) -> AnyModelConfig ### update_model(key, config) -> AnyModelConfig:
Given a key and a configuration, this will update the model Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a configuration record in the database. `config` can be either a
@ -366,31 +370,31 @@ The `ModelInstallService` class implements the
shop for all your model install needs. It provides the following shop for all your model install needs. It provides the following
functionality: functionality:
* Registering a model config record for a model already located on the - Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path. local filesystem, without moving it or changing its path.
* Installing a model already located on the local filesystem, by - Installing a model already located on the local filesystem, by
moving it into the InvokeAI root directory under the moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir` `models` folder (or wherever config parameter `models_dir`
specifies). specifies).
* Probing of models to determine their type, base type and other key - Probing of models to determine their type, base type and other key
information. information.
* Interface with the InvokeAI event bus to provide status updates on - Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process. the download, installation and registration process.
* Downloading a model from an arbitrary URL and installing it in - Downloading a model from an arbitrary URL and installing it in
`models_dir`. `models_dir`.
* Special handling for Civitai model URLs which allow the user to - Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link paste in a model page's URL or download link
* Special handling for HuggingFace repo_ids to recursively download - Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative the contents of the repository, paying attention to alternative
variants such as fp16. variants such as fp16.
* Saving tags and other metadata about the model into the invokeai database - Saving tags and other metadata about the model into the invokeai database
when fetching from a repo that provides that type of information, when fetching from a repo that provides that type of information,
(currently only Civitai and HuggingFace). (currently only Civitai and HuggingFace).
@ -439,6 +443,7 @@ required parameters:
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object | | `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) | |`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
Once initialized, the installer will provide the following methods: Once initialized, the installer will provide the following methods:
#### install_job = installer.heuristic_import(source, [config], [access_token]) #### install_job = installer.heuristic_import(source, [config], [access_token])
@ -452,12 +457,12 @@ The `source` is a string that can be any of these forms
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`) 1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`) 2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats: 3. A HuggingFace repo_id with any of the following formats:
* `model/name` -- entire model - `model/name` -- entire model
* `model/name:fp32` -- entire model, using the fp32 variant - `model/name:fp32` -- entire model, using the fp32 variant
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant - `model/name:fp16:vae` -- vae submodel, using the fp16 variant
* `model/name::vae` -- vae submodel, using default precision - `model/name::vae` -- vae submodel, using default precision
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
* `model/name::path/to/model.safetensors` -- an individual model file, default variant - `model/name::path/to/model.safetensors` -- an individual model file, default variant
Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary model files.
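The source forms above all go through the same call. A hedged sketch, assuming `installer` is an initialized `ModelInstallService` as described earlier; the repo id and URL are the examples used elsewhere in this document.

```
# Each call returns a ModelInstallJob whose status can be monitored as described below.
job1 = installer.heuristic_import("C:\\users\\fred\\model.safetensors")  # local file
job2 = installer.heuristic_import("https://civitai.com/models/58390/detail-tweaker-lora-lora")  # single-file URL
job3 = installer.heuristic_import("stabilityai/sdxl-turbo:fp16")  # HuggingFace repo_id, fp16 variant
```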
@ -561,6 +566,7 @@ details.
This is used for a model that is located on a locally-accessible Posix This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare. filesystem, such as a local disk or networked fileshare.
| **Argument** | **Type** | **Default** | **Description** | | **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------| |------------------|------------------------------|-------------|-------------------------------------------|
| `path` | str | Path | None | Path to the model file or directory | | `path` | str | Path | None | Path to the model file or directory |
@ -619,6 +625,7 @@ HuggingFace has the most complicated `ModelSource` structure:
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. | | `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. | | `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`. The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
The `variant` is one of the various diffusers formats that HuggingFace The `variant` is one of the various diffusers formats that HuggingFace
@ -654,6 +661,7 @@ in. To download these files, you must provide an
`HfFolder.get_token()` will be called to fill it in with the cached `HfFolder.get_token()` will be called to fill it in with the cached
one. one.
#### Monitoring the install job process #### Monitoring the install job process
When you create an install job with `import_model()`, it launches the When you create an install job with `import_model()`, it launches the
@ -674,6 +682,7 @@ The `ModelInstallJob` class has the following structure:
| `error_type` | `str` | Name of the exception that led to an error status | | `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error | | `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and as an event of type `EventServiceBase.model_event`, a timestamp and
@ -693,13 +702,14 @@ following keys:
| `total_bytes` | int | Total size of all the files that make up the model | | `total_bytes` | int | Total size of all the files that make up the model |
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model | | `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
`parts` is a list of dictionaries that give information on each of
the component pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.
Note that downloading events will not be issued for local models, and Note that downloading events will not be issued for local models, and
that downloading events occur _before_ the running event. that downloading events occur *before* the running event.
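A hedged sketch of consuming one of these progress payloads; the key names are taken from the tables above, while the way the callback is registered on the event bus is not shown and depends on `EventServiceBase`.

```
def report_download_progress(payload: dict) -> None:
    # Overall progress for the model being downloaded.
    total = payload["total_bytes"] or 1
    print(f"{payload['source']}: {payload['bytes'] / total:.0%} downloaded")
    # Per-file progress for each piece of the model.
    for part in payload["parts"]:
        print(f"  {part['source']} -> {part['local_path']} ({part['bytes']}/{part['total_bytes']} bytes)")
```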
##### `model_install_running` ##### `model_install_running`
@ -742,6 +752,7 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
and `cancelled`, as well as `in_terminal_state`. The last will return and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states. True if the job is in the complete, errored or cancelled states.
#### Model configuration and probing #### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe` The install service uses the `invokeai.backend.model_manager.probe`
@ -851,6 +862,7 @@ This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy. they are inside or outside the InvokeAI models hierarchy.
#### path = installer.download_and_cache(remote_source, [access_token], [timeout]) #### path = installer.download_and_cache(remote_source, [access_token], [timeout])
This utility routine will download the model file located at source, This utility routine will download the model file located at source,
@ -1028,11 +1040,11 @@ While a job is being downloaded, the queue will emit events at
periodic intervals. A typical series of events during a successful periodic intervals. A typical series of events during a successful
download session will look like this: download session will look like this:
* enqueued - enqueued
* running - running
* running - running
* running - running
* completed - completed
There will be a single enqueued event, followed by one or more running There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled` events, and finally one `completed`, `error` or `cancelled`
@ -1041,12 +1053,12 @@ events.
It is possible for a caller to pause download temporarily, in which It is possible for a caller to pause download temporarily, in which
case the events may look something like this: case the events may look something like this:
* enqueued - enqueued
* running - running
* running - running
* paused - paused
* running - running
* completed - completed
The download queue logs when downloads start and end (unless `quiet` The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress is set to True at initialization time) but doesn't log any progress
@ -1175,6 +1187,7 @@ and is equivalent to manually specifying a destination of
Here is the full list of arguments that can be provided to Here is the full list of arguments that can be provided to
`create_download_job()`: `create_download_job()`:
| **Argument** | **Type** | **Default** | **Description** | | **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------| |------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source | | `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
@ -1262,7 +1275,7 @@ for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the "type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`. model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage ### Example Usage:
``` ```
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
@ -1315,6 +1328,7 @@ This is the common base class for metadata:
| `author` | str | Model's author | | `author` | str | Model's author |
| `tags` | Set[str] | Model tags | | `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected the metadata version is read-only. However, enforcing this is expected
@ -1334,6 +1348,7 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime | Date of last commit of this model to the repo | | `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo | | `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata` #### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields: This descends from `ModelMetadataBase` and adds the following fields:
@ -1400,6 +1415,7 @@ testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods: repo-specific fetching methods:
#### HuggingFaceMetadataFetch #### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a This overrides its base class `from_json()` method to return a
@ -1418,6 +1434,7 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage ### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model The `ModelMetadataStore` provides a simple facility to store model
@ -1550,6 +1567,7 @@ The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as record returned by the model record `get_model()` method, as well as
the in-memory loaded model: the in-memory loaded model:
| **Attribute Name** | **Type** | **Description** | | **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------| |----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. | | `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
@ -1563,6 +1581,7 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious. models. The others are obvious.
`LoadedModel` acts as a context manager. The context loads the model `LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns in the execution device for the duration of the context, and returns
@ -1576,9 +1595,9 @@ with model_info as vae:
`get_model_by_key()` may raise any of the following exceptions: `get_model_by_key()` may raise any of the following exceptions:
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
+- `UnknownModelException` -- key not in database
+- `ModelNotFoundException` -- key in database but model not found at path
+- `NotImplementedException` -- the loader doesn't know how to load this type of model
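A hedged sketch that combines `get_model_by_key()`, the exceptions listed above and the `LoadedModel` context manager; the key is illustrative, `loader` is assumed to be the model load service described in this section, and the `vae.decode()` call stands in for whatever inference you run on the loaded model.

```
try:
    model_info = loader.get_model_by_key("0123456789abcdef0123456789abcdef")
except UnknownModelException:
    print("key not in database")
except ModelNotFoundException:
    print("key in database but model not found at path")
else:
    with model_info as vae:
        # While the context is active the model is locked into the execution device (e.g. VRAM).
        image = vae.decode(latents)  # `latents` assumed to come from an earlier denoising step
```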
### Emitting model loading events ### Emitting model loading events
@ -1705,7 +1724,6 @@ object, or in `context.services.model_manager` from within an
invocation. invocation.
In the examples below, we have retrieved the manager using: In the examples below, we have retrieved the manager using:
``` ```
mm = ApiDependencies.invoker.services.model_manager mm = ApiDependencies.invoker.services.model_manager
``` ```


@@ -19,8 +19,6 @@ their descriptions.
 | Conditioning Primitive | A conditioning tensor primitive value |
 | Content Shuffle Processor | Applies content shuffle processing to image |
 | ControlNet | Collects ControlNet info to pass to other nodes |
-| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
-| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
 | Denoise Latents | Denoises noisy latents to decodable images |
 | Divide Integers | Divides two numbers |
 | Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |


@@ -9,7 +9,7 @@ set INVOKEAI_ROOT=.
 :start
 echo Desired action:
 echo 1. Generate images with the browser-based interface
-echo 2. Run textual inversion training
+echo 2. Invoke Model Training
 echo 3. Merge models (diffusers type only)
 echo 4. Download and install models
 echo 5. Change InvokeAI startup options
@@ -25,8 +25,7 @@ IF /I "%choice%" == "1" (
     echo Starting the InvokeAI browser-based UI..
     python .venv\Scripts\invokeai-web.exe %*
 ) ELSE IF /I "%choice%" == "2" (
-    echo Starting textual inversion training..
-    python .venv\Scripts\invokeai-ti.exe --gui
+    echo To use Invoke Training for LoRA, TI, and more - Visit https://github.com/invoke-ai/invoke-training
 ) ELSE IF /I "%choice%" == "3" (
     echo Starting model merging script..
     python .venv\Scripts\invokeai-merge.exe --gui
View File
@ -59,8 +59,7 @@ do_choice() {
;; ;;
2) 2)
clear clear
printf "Textual inversion training\n" printf "To use Invoke Training for LoRA, TI, and more - Visit https://github.com/invoke-ai/invoke-training\n"
invokeai-ti --gui $PARAMS
;; ;;
3) 3)
clear clear
@ -118,7 +117,7 @@ do_choice() {
do_dialog() { do_dialog() {
options=( options=(
1 "Generate images with a browser-based interface" 1 "Generate images with a browser-based interface"
2 "Textual inversion training" 2 "Run Invoke Training"
3 "Merge models (diffusers type only)" 3 "Merge models (diffusers type only)"
4 "Download and install models" 4 "Download and install models"
5 "Change InvokeAI startup options" 5 "Change InvokeAI startup options"
@ -151,7 +150,7 @@ do_line_input() {
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n" printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
printf "What would you like to do?\n" printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n" printf "1: Generate images using the browser-based interface\n"
printf "2: Run textual inversion training\n" printf "2: Run Invoke Training\n"
printf "3: Merge models (diffusers type only)\n" printf "3: Merge models (diffusers type only)\n"
printf "4: Download and install models\n" printf "4: Download and install models\n"
printf "5: Change InvokeAI startup options\n" printf "5: Change InvokeAI startup options\n"
View File
@ -26,6 +26,7 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor
@ -92,9 +93,10 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
) )
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db), model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
download_queue=download_queue_service, download_queue=download_queue_service,
events=events, events=events,
) )
View File
@ -3,7 +3,9 @@
import pathlib import pathlib
import shutil import shutil
from typing import Any, Dict, List, Optional from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -13,10 +15,13 @@ from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException, InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
@ -25,6 +30,8 @@ from invokeai.backend.model_manager.config import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -40,6 +47,15 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
############################################################################## ##############################################################################
# These are example inputs and outputs that are used in places where Swagger # These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example. # is unable to generate a correct example.
@ -50,16 +66,19 @@ example_model_config = {
"base": "sd-1", "base": "sd-1",
"type": "main", "type": "main",
"format": "checkpoint", "format": "checkpoint",
"config_path": "string", "config": "string",
"key": "string", "key": "string",
"hash": "string", "original_hash": "string",
"current_hash": "string",
"description": "string", "description": "string",
"source": "string", "source": "string",
"converted_at": 0, "last_modified": 0,
"vae": "string",
"variant": "normal", "variant": "normal",
"prediction_type": "epsilon", "prediction_type": "epsilon",
"repo_variant": "fp16", "repo_variant": "fp16",
"upcast_attention": False, "upcast_attention": False,
"ztsnr_training": False,
} }
example_model_input = { example_model_input = {
@ -68,12 +87,50 @@ example_model_input = {
"base": "sd-1", "base": "sd-1",
"type": "main", "type": "main",
"format": "checkpoint", "format": "checkpoint",
"config_path": "configs/stable-diffusion/v1-inference.yaml", "config": "configs/stable-diffusion/v1-inference.yaml",
"description": "Model description", "description": "Model description",
"vae": None, "vae": None,
"variant": "normal", "variant": "normal",
} }
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
############################################################################## ##############################################################################
# ROUTES # ROUTES
############################################################################## ##############################################################################
@ -153,16 +210,48 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.get("/summary", operation_id="list_model_summary") @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary( async def list_model_summary(
# page: int = Query(default=0, description="The page to get"), page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"), per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]: ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data.""" """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results return results
@model_manager_router.get(
"/i/{key}/metadata",
operation_id="get_model_metadata",
responses={
200: {
"description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
class FoundModel(BaseModel): class FoundModel(BaseModel):
@ -234,6 +323,19 @@ async def scan_for_models(
return scan_results return scan_results
@model_manager_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_manager_router.patch( @model_manager_router.patch(
"/i/{key}", "/i/{key}",
operation_id="update_model_record", operation_id="update_model_record",
@ -250,13 +352,15 @@ async def scan_for_models(
) )
async def update_model_record( async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")], key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)], info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Update a model's config.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
try: try:
model_response: AnyModelConfig = record_store.update_model(key, changes=changes) model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}") logger.info(f"Updated model: {key}")
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -268,14 +372,14 @@ async def update_model_record(
@model_manager_router.delete( @model_manager_router.delete(
"/i/{key}", "/i/{key}",
operation_id="delete_model", operation_id="del_model_record",
responses={ responses={
204: {"description": "Model deleted successfully"}, 204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"}, 404: {"description": "Model not found"},
}, },
status_code=204, status_code=204,
) )
async def delete_model( async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."), key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response: ) -> Response:
""" """
@ -296,39 +400,42 @@ async def delete_model(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.post( @model_manager_router.post(
# "/i/", "/i/",
# operation_id="add_model_record", operation_id="add_model_record",
# responses={ responses={
# 201: { 201: {
# "description": "The model added successfully", "description": "The model added successfully",
# "content": {"application/json": {"example": example_model_config}}, "content": {"application/json": {"example": example_model_config}},
# }, },
# 409: {"description": "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
# 415: {"description": "Unrecognized file/folder format"}, 415: {"description": "Unrecognized file/folder format"},
# }, },
# status_code=201, status_code=201,
# ) )
# async def add_model_record( async def add_model_record(
# config: Annotated[ config: Annotated[
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
# ], ],
# ) -> AnyModelConfig: ) -> AnyModelConfig:
# """Add a model using the configuration information appropriate for its type.""" """Add a model using the configuration information appropriate for its type."""
# logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
# record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
# try: if config.key == "<NOKEY>":
# record_store.add_model(config) config.key = sha1(randbytes(100)).hexdigest()
# except DuplicateModelException as e: logger.info(f"Created model {config.key} for {config.name}")
# logger.error(str(e)) try:
# raise HTTPException(status_code=409, detail=str(e)) record_store.add_model(config.key, config)
# except InvalidModelException as e: except DuplicateModelException as e:
# logger.error(str(e)) logger.error(str(e))
# raise HTTPException(status_code=415) raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# # now fetch it out # now fetch it out
# result: AnyModelConfig = record_store.get_model(config.key) result: AnyModelConfig = record_store.get_model(config.key)
# return result return result
@model_manager_router.post( @model_manager_router.post(
@ -344,7 +451,6 @@ async def delete_model(
) )
async def install_model( async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"), source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this? # TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body( config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
@ -387,7 +493,6 @@ async def install_model(
source=source, source=source,
config=config, config=config,
access_token=access_token, access_token=access_token,
inplace=bool(inplace),
) )
logger.info(f"Started installation of {source}") logger.info(f"Started installation of {source}")
except UnknownModelException as e: except UnknownModelException as e:
@ -403,10 +508,10 @@ async def install_model(
@model_manager_router.get( @model_manager_router.get(
"/install", "/import",
operation_id="list_model_installs", operation_id="list_model_install_jobs",
) )
async def list_model_installs() -> List[ModelInstallJob]: async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return the list of model install jobs. """Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@ -420,8 +525,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
* "cancelled" -- Job was cancelled before completion. * "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base Once completed, information about the model such as its size, base
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers, model, type, and metadata can be retrieved from the `config_out`
information on individual files can be retrieved from `download_parts`. field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
See the example and schema below for more information. See the example and schema below for more information.
""" """
@ -430,7 +536,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
@model_manager_router.get( @model_manager_router.get(
"/install/{id}", "/import/{id}",
operation_id="get_model_install_job", operation_id="get_model_install_job",
responses={ responses={
200: {"description": "Success"}, 200: {"description": "Success"},
@ -450,7 +556,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
@model_manager_router.delete( @model_manager_router.delete(
"/install/{id}", "/import/{id}",
operation_id="cancel_model_install_job", operation_id="cancel_model_install_job",
responses={ responses={
201: {"description": "The job was cancelled successfully"}, 201: {"description": "The job was cancelled successfully"},
@ -468,8 +574,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job) installer.cancel_job(job)
@model_manager_router.delete( @model_manager_router.patch(
"/install", "/import",
operation_id="prune_model_install_jobs", operation_id="prune_model_install_jobs",
responses={ responses={
204: {"description": "All completed and errored jobs have been pruned"}, 204: {"description": "All completed and errored jobs have been pruned"},
@ -548,8 +654,7 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict # temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name original_name = model_config.name
model_config.name = f"{original_name}.DELETE" model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name) store.update_model(key, config=model_config)
store.update_model(key, changes=changes)
# install the diffusers # install the diffusers
try: try:
@ -558,7 +663,7 @@ async def convert_model(
config={ config={
"name": original_name, "name": original_name,
"description": model_config.description, "description": model_config.description,
"hash": model_config.hash, "original_hash": model_config.original_hash,
"source": model_config.source, "source": model_config.source,
}, },
) )
@ -566,6 +671,10 @@ async def convert_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file # delete the original safetensors file
installer.delete(key) installer.delete(key)
@ -577,66 +686,66 @@ async def convert_model(
return new_config return new_config
# @model_manager_router.put( @model_manager_router.put(
# "/merge", "/merge",
# operation_id="merge", operation_id="merge",
# responses={ responses={
# 200: { 200: {
# "description": "Model converted successfully", "description": "Model converted successfully",
# "content": {"application/json": {"example": example_model_config}}, "content": {"application/json": {"example": example_model_config}},
# }, },
# 400: {"description": "Bad request"}, 400: {"description": "Bad request"},
# 404: {"description": "Model not found"}, 404: {"description": "Model not found"},
# 409: {"description": "There is already a model registered at this location"}, 409: {"description": "There is already a model registered at this location"},
# }, },
# ) )
# async def merge( async def merge(
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
# force: bool = Body( force: bool = Body(
# description="Force merging of models created with different versions of diffusers", description="Force merging of models created with different versions of diffusers",
# default=False, default=False,
# ), ),
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
# merge_dest_directory: Optional[str] = Body( merge_dest_directory: Optional[str] = Body(
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)", description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
# default=None, default=None,
# ), ),
# ) -> AnyModelConfig: ) -> AnyModelConfig:
# """ """
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
# ``` ```
# Argument Description [default] Argument Description [default]
# -------- ---------------------- -------- ----------------------
# keys List of 2-3 model keys to merge together. All models must use the same base type. keys List of 2-3 model keys to merge together. All models must use the same base type.
# merged_model_name Name for the merged model [Concat model names] merged_model_name Name for the merged model [Concat model names]
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False] force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
# merge_dest_directory Specify a directory to store the merged model in [models directory] merge_dest_directory Specify a directory to store the merged model in [models directory]
# ``` ```
# """ """
# logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
# try: try:
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}") logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
# installer = ApiDependencies.invoker.services.model_manager.install installer = ApiDependencies.invoker.services.model_manager.install
# merger = ModelMerger(installer) merger = ModelMerger(installer)
# model_names = [installer.record_store.get_model(x).name for x in keys] model_names = [installer.record_store.get_model(x).name for x in keys]
# response = merger.merge_diffusion_models_and_save( response = merger.merge_diffusion_models_and_save(
# model_keys=keys, model_keys=keys,
# merged_model_name=merged_model_name or "+".join(model_names), merged_model_name=merged_model_name or "+".join(model_names),
# alpha=alpha, alpha=alpha,
# interp=interp, interp=interp,
# force=force, force=force,
# merge_dest_directory=dest, merge_dest_directory=dest,
# ) )
# except UnknownModelException: except UnknownModelException:
# raise HTTPException( raise HTTPException(
# status_code=404, status_code=404,
# detail=f"One or more of the models '{keys}' not found", detail=f"One or more of the models '{keys}' not found",
# ) )
# except ValueError as e: except ValueError as e:
# raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
# return response return response
View File
@ -173,16 +173,6 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
) )
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation( @invocation(
"create_gradient_mask", "create_gradient_mask",
title="Create Gradient Mask", title="Create Gradient Mask",
@ -203,42 +193,38 @@ class CreateGradientMaskInvocation(BaseInvocation):
) )
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput: def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L") mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur": if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1 # redistribute blur so that the edges are 0 and blur out to 1
blur_tensor = (blur_tensor - 0.5) * 2 blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged": if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold # wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor) blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
else: else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else: # multiply original mask to force actually masked regions to 0
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) blur_tensor = mask_tensor * blur_tensor
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1)) mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor return DenoiseMaskOutput.build(
expanded_mask = torch.where((blur_tensor < 1), 0, 1) mask_name=mask_name,
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L") masked_latents_name=None,
expanded_image_dto = context.images.save(expanded_mask_image) gradient=True,
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
) )
@ -374,6 +360,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
) -> ConditioningData: ) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -383,6 +370,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings=c, text_embeddings=c,
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier, guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold, threshold=0.0, # threshold,
warmup=0.2, # warmup, warmup=0.2, # warmup,
@ -789,7 +777,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
) )
result_latents = pipeline.latents_from_embeddings( (
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
latents=latents, latents=latents,
timesteps=timesteps, timesteps=timesteps,
init_timestep=init_timestep, init_timestep=init_timestep,
View File
@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
key=key, key=key,
submodel_type=SubModelType.VAE, submodel_type=SubModelType.Vae,
), ),
), ),
) )
View File
@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
key=model_key, key=model_key,
submodel_type=SubModelType.VAE, submodel_type=SubModelType.Vae,
), ),
), ),
) )
@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
key=model_key, key=model_key,
submodel_type=SubModelType.VAE, submodel_type=SubModelType.Vae,
), ),
), ),
) )
View File
@ -256,7 +256,6 @@ class InvokeAIAppConfig(InvokeAISettings):
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development) profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development) profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development) profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other) version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
View File
@ -18,9 +18,10 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum): class InstallStatus(str, Enum):
"""State of an install job running in the background.""" """State of an install job running in the background."""
@ -150,13 +151,6 @@ ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type") Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
] ]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
CivitaiModelSource: ModelSourceType.CivitAI,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel): class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request.""" """Object that tracks the current status of an install request."""
@ -266,6 +260,7 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None, event_bus: Optional["EventServiceBase"] = None,
): ):
""" """
@ -352,7 +347,6 @@ class ModelInstallServiceBase(ABC):
source: str, source: str,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob: ) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions. r"""Install the indicated model using heuristics to interpret user intentions.
@ -398,7 +392,7 @@ class ModelInstallServiceBase(ABC):
will override corresponding autoassigned probe fields in the will override corresponding autoassigned probe fields in the
model's config record. Use it to override model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`, `name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, and/or `image_size`. `prediction_type`, `image_size`, and/or `ztsnr_training`.
This will download the model located at `source`, This will download the model located at `source`,
probe it, and install it into the models directory. probe it, and install it into the models directory.
View File
@ -7,6 +7,7 @@ import time
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
@ -20,15 +21,11 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
CheckpointConfigBase,
InvalidModelConfigException, InvalidModelConfigException,
ModelRepoVariant, ModelRepoVariant,
ModelSourceType,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
@ -38,14 +35,12 @@ from invokeai.backend.model_manager.metadata import (
ModelMetadataWithFiles, ModelMetadataWithFiles,
RemoteModelFile, RemoteModelFile,
) )
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import ( from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
CivitaiModelSource, CivitaiModelSource,
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
@ -95,6 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._running = False self._running = False
self._session = session self._session = session
self._next_job_id = 0 self._next_job_id = 0
self._metadata_store = record_store.metadata_store # for convenience
@property @property
def app_config(self) -> InvokeAIAppConfig: # noqa D102 def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -143,7 +139,6 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {} config = config or {}
if not config.get("source"): if not config.get("source"):
config["source"] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
return self._register(model_path, config) return self._register(model_path, config)
def install_path( def install_path(
@ -153,11 +148,10 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
config = config or {} config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
if self._app_config.skip_model_hash: info: AnyModelConfig = self._probe_model(Path(model_path), config)
config["hash"] = uuid_string()
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
if preferred_name := config.get("name"): if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix) preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -183,14 +177,13 @@ class ModelInstallService(ModelInstallServiceBase):
source: str, source: str,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob: ) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values()) variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source), inplace=inplace) source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source): elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource( source_obj = HFModelSource(
repo_id=match.group(1), repo_id=match.group(1),
@ -285,7 +278,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info("Model installer (re)initialized") self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config) search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear() self._models_installed.clear()
@ -379,18 +372,15 @@ class ModelInstallService(ModelInstallServiceBase):
job.bytes = job.total_bytes job.bytes = job.total_bytes
self._signal_job_running(job) self._signal_job_running(job)
job.config_in["source"] = str(job.source) job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
if job.inplace: if job.inplace:
key = self.register_path(job.local_path, job.config_in) key = self.register_path(job.local_path, job.config_in)
else: else:
key = self.install_path(job.local_path, job.config_in) key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key) job.config_out = self.record_store.get_model(key)
# enter the metadata, if there is any
if job.source_metadata:
self._metadata_store.add_metadata(key, job.source_metadata)
self._signal_job_completed(job) self._signal_job_completed(job)
except InvalidModelConfigException as excp: except InvalidModelConfigException as excp:
@ -476,7 +466,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Moving {model.name} to {new_path}.") self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path) new_path = self._move_model(old_path, new_path)
model.path = new_path.relative_to(models_dir).as_posix() model.path = new_path.relative_to(models_dir).as_posix()
self.record_store.update_model(key, ModelRecordChanges(path=model.path)) self.record_store.update_model(key, model)
return model return model
def _scan_register(self, model: Path) -> bool: def _scan_register(self, model: Path) -> bool:
@ -528,14 +518,22 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path) move(old_path, new_path)
return new_path return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register( def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str: ) -> str:
config = config or {} key = self._create_key()
if config and not config.get("key", None):
if self._app_config.skip_model_hash: config["key"] = key
config["hash"] = uuid_string()
info = info or ModelProbe.probe(model_path, config) info = info or ModelProbe.probe(model_path, config)
model_path = model_path.absolute() model_path = model_path.absolute()
@ -545,11 +543,11 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix() info.path = model_path.as_posix()
# add 'main' specific fields # add 'main' specific fields
if isinstance(info, CheckpointConfigBase): if hasattr(info, "config"):
# make config relative to our root # make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve() legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix() info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info) self.record_store.add_model(info.key, info)
return info.key return info.key
def _next_id(self) -> int: def _next_id(self) -> int:
@ -570,15 +568,13 @@ class ModelInstallService(ModelInstallServiceBase):
source=source, source=source,
config_in=config or {}, config_in=config or {},
local_path=Path(source.path), local_path=Path(source.path),
inplace=source.inplace or False, inplace=source.inplace,
) )
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
if not source.access_token: if not source.access_token:
self._logger.info("No Civitai access token provided; some models may not be downloadable.") self._logger.info("No Civitai access token provided; some models may not be downloadable.")
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id( metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
str(source.version_id)
)
assert isinstance(metadata, ModelMetadataWithFiles) assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(session=self._session) remote_files = metadata.download_urls(session=self._session)
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files) return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
@ -606,17 +602,15 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from Civitai or HuggingFace will be handled specially # URLs from Civitai or HuggingFace will be handled specially
url_patterns = {
r"^https?://civitai.com/": CivitaiMetadataFetch,
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
}
metadata = None metadata = None
fetcher = None for pattern, fetcher in url_patterns.items():
try: if re.match(pattern, str(source.url), re.IGNORECASE):
fetcher = self.get_fetcher_from_url(str(source.url)) metadata = fetcher(self._session).from_url(source.url)
except ValueError: break
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is CivitaiMetadataFetch:
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}") self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles): if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session) remote_files = metadata.download_urls(session=self._session)
@ -631,7 +625,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_remote_model( def _import_remote_model(
self, self,
source: HFModelSource | CivitaiModelSource | URLModelSource, source: ModelSource,
remote_files: List[RemoteModelFile], remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata], metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]], config: Optional[Dict[str, Any]],
@ -659,7 +653,7 @@ class ModelInstallService(ModelInstallServiceBase):
# In the event that there is a subfolder specified in the source, # In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid # we need to remove it from the destination path in order to avoid
# creating unwanted subfolders # creating unwanted subfolders
if isinstance(source, HFModelSource) and source.subfolder: if hasattr(source, "subfolder") and source.subfolder:
root = Path(remote_files[0].path.parts[0]) root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder subfolder = root / source.subfolder
else: else:
@ -846,11 +840,3 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{job.source}: model installation was cancelled") self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source)) self._event_bus.emit_model_install_cancelled(str(job.source))
@staticmethod
def get_fetcher_from_url(url: str):
if re.match(r"^https?://civitai.com/", url.lower()):
return CivitaiMetadataFetch
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
View File
@ -0,0 +1,9 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]
View File
@ -0,0 +1,65 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
View File
@ -0,0 +1,222 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -6,19 +6,20 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -59,33 +60,11 @@ class ModelSummary(BaseModel):
tags: Set[str] = Field(description="tags associated with model") tags: Set[str] = Field(description="tags associated with model")
class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model."""
# Changes applicable to all models
name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
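For context, a hedged sketch of how a partial update flows through this changes object, matching the update_model signature on the same side of this diff; `record_store` stands in for an assumed ModelRecordServiceBase implementation:

# Illustrative only: unset fields are left alone, set fields are validated and applied.
changes = ModelRecordChanges(name="My fine-tune", trigger_phrases={"myft-style"})
updated = record_store.update_model("some-model-key", changes)
assert updated.name == "My fine-tune"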
class ModelRecordServiceBase(ABC): class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs.""" """Abstract base class for storage and retrieval of model configs."""
@abstractmethod @abstractmethod
def add_model(self, config: AnyModelConfig) -> AnyModelConfig: def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.
@ -109,12 +88,13 @@ class ModelRecordServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig: def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
""" """
Update the model, returning the updated version. Update the model, returning the updated version.
:param key: Unique key for the model to be updated. :param key: Unique key for the model to be updated
:param changes: A set of changes to apply to this model. Changes are validated before being written. :param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
""" """
pass pass
@ -129,6 +109,40 @@ class ModelRecordServiceBase(ABC):
""" """
pass pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
pass
@abstractmethod
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
pass
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
pass
@abstractmethod
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
pass
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
pass
@abstractmethod @abstractmethod
def list_models( def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
@ -203,3 +217,21 @@ class ModelRecordServiceBase(ABC):
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'." f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
) )
return model_configs[0] return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -43,7 +43,7 @@ import json
import sqlite3 import sqlite3
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import Any, Dict, List, Optional, Set, Tuple, Union
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -53,11 +53,12 @@ from invokeai.backend.model_manager.config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
DuplicateModelException, DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy, ModelRecordOrderBy,
ModelRecordServiceBase, ModelRecordServiceBase,
ModelSummary, ModelSummary,
@ -68,7 +69,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase): def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -77,13 +78,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = db.conn.cursor() self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
"""Return the underlying database.""" """Return the underlying database."""
return self._db return self._db
def add_model(self, config: AnyModelConfig) -> AnyModelConfig: def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.
@ -93,19 +95,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions. Can raise DuplicateModelException and InvalidModelConfigException exceptions.
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT INTO models ( INSERT INTO model_config (
id, id,
original_hash,
config config
) )
VALUES (?,?); VALUES (?,?,?);
""", """,
( (
config.key, key,
config.model_dump_json(), record.original_hash,
json_serialized,
), ),
) )
self._db.conn.commit() self._db.conn.commit()
@ -113,12 +119,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e: except sqlite3.IntegrityError as e:
self._db.conn.rollback() self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e): if "model_config.path" in str(e):
msg = f"A model with path '{config.path}' is already installed" msg = f"A model with path '{record.path}' is already installed"
elif "models.name" in str(e): elif "model_config.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed" msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else: else:
msg = f"A model with key '{config.key}' is already installed" msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e raise DuplicateModelException(msg) from e
else: else:
raise e raise e
@ -126,7 +132,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
return self.get_model(config.key) return self.get_model(key)
def del_model(self, key: str) -> None: def del_model(self, key: str) -> None:
""" """
@ -140,7 +146,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
DELETE FROM models DELETE FROM model_config
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
@ -152,20 +158,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig: def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
record = self.get_model(key) """
Update the model, returning the updated version.
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
UPDATE models UPDATE model_config
SET SET
config=? config=?
WHERE id=?; WHERE id=?;
@ -192,7 +199,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config, strftime('%s',updated_at) FROM models SELECT config, strftime('%s',updated_at) FROM model_config
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
@ -213,7 +220,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
select count(*) FROM models select count(*) FROM model_config
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
@ -239,8 +246,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all If none of the optional filters are passed, will return all
models in the database. models in the database.
""" """
where_clause: list[str] = [] results = []
bindings: list[str] = [] where_clause = []
bindings = []
if model_name: if model_name:
where_clause.append("name=?") where_clause.append("name=?")
bindings.append(model_name) bindings.append(model_name)
@ -257,13 +265,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT config, strftime('%s',updated_at) FROM models select config, strftime('%s',updated_at) FROM model_config
{where}; {where};
""", """,
tuple(bindings), tuple(bindings),
) )
result = self._cursor.fetchall() results = [
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result] ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -272,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config, strftime('%s',updated_at) FROM models SELECT config, strftime('%s',updated_at) FROM model_config
WHERE path=?; WHERE path=?;
""", """,
(str(path),), (str(path),),
@ -283,13 +292,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]: def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash.""" """Return models with the indicated original_hash."""
results = [] results = []
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config, strftime('%s',updated_at) FROM models SELECT config, strftime('%s',updated_at) FROM model_config
WHERE hash=?; WHERE original_hash=?;
""", """,
(hash,), (hash,),
) )
@ -298,35 +307,83 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
] ]
return results return results
@property
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
store = self.metadata_store
try:
metadata = store.get_metadata(key)
return metadata
except UnknownMetadataException:
return None
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStoreSQL(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata()
def list_models( def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]: ) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database.""" """Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = { ordering = {
ModelRecordOrderBy.Default: "type, base, format, name", ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
ModelRecordOrderBy.Type: "type", ModelRecordOrderBy.Type: "a.type",
ModelRecordOrderBy.Base: "base", ModelRecordOrderBy.Base: "a.base",
ModelRecordOrderBy.Name: "name", ModelRecordOrderBy.Name: "a.name",
ModelRecordOrderBy.Format: "format", ModelRecordOrderBy.Format: "a.format",
} }
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
"""Fix up results so that there are no null values."""
result: Dict[str, Union[str, int, Set[str]]] = {}
for key, item in summary.items():
result[key] = item or ""
result["tags"] = set(json.loads(summary["tags"] or "[]"))
return result
# Lock so that the database isn't updated while we're doing the two queries. # Lock so that the database isn't updated while we're doing the two queries.
with self._db.lock: with self._db.lock:
# query1: get the total number of model configs # query1: get the total number of model configs
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
select count(*) from models; select count(*) from model_config;
""", """,
(), (),
) )
total = int(self._cursor.fetchone()[0]) total = int(self._cursor.fetchone()[0])
# query2: fetch key fields # query2: fetch key fields from the join of model_config and model_metadata
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT config SELECT a.id as key, a.type, a.base, a.format, a.name,
FROM models json_extract(a.config, '$.description') as description,
json_extract(b.metadata, '$.tags') as tags
FROM model_config AS a
LEFT JOIN model_metadata AS b on a.id=b.id
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ? LIMIT ?
OFFSET ?; OFFSET ?;
@ -337,7 +394,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
), ),
) )
rows = self._cursor.fetchall() rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows] items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
return PaginatedResults( return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
) )

View File

@ -1,35 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
"""Starts the session runner"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass
class SessionProcessorBase(ABC): class SessionProcessorBase(ABC):

View File

@ -2,14 +2,13 @@ import traceback
from contextlib import suppress from contextlib import suppress
from threading import BoundedSemaphore, Thread from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent from threading import Event as ThreadEvent
from typing import Callable, Optional, Union from typing import Optional
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -17,164 +16,15 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase, SessionRunnerBase from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
def __init__(
self,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
):
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
# Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node
invocation = queue_item.session.next()
if invocation is None:
# If there are no more invocations, complete the graph
break
# Build invocation context (the node-facing API)
self.run_node(invocation.id, queue_item)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
"""Complete the graph"""
self.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run after a node is executed"""
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this raises a NodeNotFoundError, it is handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try:
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
context = build_invocation_context(
data=data,
services=self.services,
cancel_event=self.cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
queue_item.session.set_node_error(invocation.id, error)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self.services.logger.error(error)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=e.__class__.__name__,
error=error,
)
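A hedged sketch of how these hooks could be wired together; the runner and the DefaultSessionProcessor constructor shown just below are real, but the callback and the wiring are illustrative and would normally live wherever the processor is constructed:

# Illustrative only: log every node before it runs.
def log_node_start(invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool:
    print(f"starting {invocation.get_type()} in session {queue_item.session_id}")
    return True

runner = DefaultSessionRunner(on_before_run_node=log_node_start)
processor = DefaultSessionProcessor(session_runner=runner)  # start() is called later by the Invoker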
class DefaultSessionProcessor(SessionProcessorBase): class DefaultSessionProcessor(SessionProcessorBase):
"""Processes sessions from the session queue""" def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None:
self._invoker: Invoker = invoker self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None self._invocation: Optional[BaseInvocation] = None
self.on_before_run_session = on_before_run_session
self.on_after_run_session = on_after_run_session
self._resume_event = ThreadEvent() self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent() self._stop_event = ThreadEvent()
@ -209,7 +59,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
"cancel_event": self._cancel_event, "cancel_event": self._cancel_event,
}, },
) )
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self._thread.start() self._thread.start()
def stop(self, *args, **kwargs) -> None: def stop(self, *args, **kwargs) -> None:
@ -268,17 +117,112 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear() cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler # If profiling is enabled, start the profiler
if self._profiler is not None: if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id) self._profiler.start(profile_id=self._queue_item.session_id)
# Run the graph # Prepare invocations and take the first
self.session_runner.run(queue_item=self._queue_item) self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats # If we are profiling, stop the profiler and dump the profile & stats
if self._profiler: if self._profiler:
profile_path = self._profiler.stop() profile_path = self._profiler.stop()
@ -286,16 +230,17 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.performance_statistics.dump_stats( self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
) )
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error. # we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError): with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats() self._invoker.services.performance_statistics.reset_stats()
# If we have a on_after_run_session callback, call it # Set the invocation to None to prepare for the next session
if self.on_after_run_session is not None: self._invocation = None
self.on_after_run_session(self._queue_item) else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# The session is complete, immediately poll for next session # The session is complete, immediately poll for next session
self._queue_item = None self._queue_item = None
@ -329,4 +274,3 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear() poll_now_event.clear()
self._queue_item = None self._queue_item = None
self._thread_semaphore.release() self._thread_semaphore.release()
self._invoker.services.logger.debug("Session processor stopped")

View File

@ -9,7 +9,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -36,7 +35,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5()) migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6()) migrator.register_migration(build_migration_6())
migrator.register_migration(build_migration_7())
migrator.run_migrations() migrator.run_migrations()
return db return db
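For orientation, a hedged sketch of the registration pattern used above; `build_migration_8` and its callback are made up to show the shape and are not part of this change:

# Illustrative only: the general shape of adding another migration.
def build_migration_8() -> Migration:
    def callback(cursor: sqlite3.Cursor) -> None:
        cursor.execute("CREATE TABLE IF NOT EXISTS example (id TEXT PRIMARY KEY);")
    return Migration(from_version=7, to_version=8, callback=callback)

migrator.register_migration(build_migration_8())
migrator.run_migrations()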

View File

@ -1,88 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration7Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_models_table(cursor)
self._drop_old_models_tables(cursor)
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
tables = ["model_records", "model_metadata", "model_tags", "tags"]
for table in tables:
cursor.execute(f"DROP TABLE IF EXISTS {table};")
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the v4.0.0 models table."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS models (
id TEXT NOT NULL PRIMARY KEY,
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS models_updated_at
AFTER UPDATE
ON models FOR EACH ROW
BEGIN
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_7() -> Migration:
"""
Build the migration from database version 6 to 7.
This migration does the following:
- Adds the new models table
- Drops the old model_records, model_metadata, model_tags and tags tables.
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
"""
migration_7 = Migration(
from_version=6,
to_version=7,
callback=Migration7Callback(),
)
return migration_7
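To make the generated-column approach concrete, here is a hedged sketch of inserting and querying a row; the values and the `cursor` handle are illustrative:

# Illustrative only: the virtual columns expose JSON fields to SQL without any Python-side parsing.
config = {
    "hash": "deadbeef", "base": "sd-1", "type": "main", "path": "/models/foo",
    "format": "diffusers", "name": "foo", "source": "/models/foo", "source_type": "path",
}
cursor.execute("INSERT INTO models (id, config) VALUES (?, ?);", ("abc123", json.dumps(config)))
cursor.execute("SELECT name, base FROM models WHERE type = 'main';")  # served by type_index
print(cursor.fetchall())  # -> [('foo', 'sd-1')]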

View File

@ -3,6 +3,7 @@
import json import json
import sqlite3 import sqlite3
from hashlib import sha1
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -21,7 +22,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory, ModelConfigFactory,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.hash import ModelHash from invokeai.backend.model_manager.hash import FastModelHash
ModelsValidator = TypeAdapter(AnyModelConfig) ModelsValidator = TypeAdapter(AnyModelConfig)
@ -72,27 +73,19 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/") base_type, model_type, model_name = str(model_key).split("/")
try: try:
hash = ModelHash().hash(self.config.models_path / stanza.path) hash = FastModelHash.hash(self.config.models_path / stanza.path)
except OSError: except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type) stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type) stanza["type"] = ModelType(model_type)
stanza["name"] = model_name stanza["name"] = model_name
stanza["original_hash"] = hash stanza["original_hash"] = hash
stanza["current_hash"] = hash stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094 new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@ -102,7 +95,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}") self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config) self._update_model(key, new_config)
else: else:
self.logger.info(f"Adding model {model_name} with key {new_key}") self.logger.info(f"Adding model {model_name} with key {model_key}")
self._add_model(new_key, new_config) self._add_model(new_key, new_config)
except DuplicateModelException: except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database") self.logger.warning(f"Model {model_name} is already in the database")
@ -150,14 +143,9 @@ class MigrateModelYamlToDb1:
""", """,
( (
key, key,
record.hash, record.original_hash,
json_serialized, json_serialized,
), ),
) )
except sqlite3.IntegrityError as exc: except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@ -17,8 +17,7 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example. See :class:`Migration` for an example.
""" """
def __call__(self, cursor: sqlite3.Cursor) -> None: def __call__(self, cursor: sqlite3.Cursor) -> None: ...
...
class MigrationError(RuntimeError): class MigrationError(RuntimeError):

View File

@ -0,0 +1,55 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.shared.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check: make sure the graph is at least reasonably shaped
if (
not isinstance(graph, dict)
or "nodes" not in graph
or not isinstance(graph["nodes"], dict)
or "edges" not in graph
or not isinstance(graph["edges"], list)
):
# something has gone terribly awry, return None
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it; iterate over a copy so removals don't skip entries
for edge in list(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return None
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
graph["edges"].remove(edge)
return graph
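A minimal usage sketch; the session payload below is a made-up fragment that is merely shaped like a real session:

# Illustrative only: the accumulator node and any edge touching it are stripped from the result.
raw = json.dumps({
    "graph": {
        "nodes": {
            "noise": {"type": "noise"},
            "metadata_accumulator": {"type": "metadata_accumulator"},
        },
        "edges": [
            {
                "source": {"node_id": "noise", "field": "noise"},
                "destination": {"node_id": "metadata_accumulator", "field": "noise"},
            }
        ],
    }
})
graph = get_metadata_graph_from_raw_session(raw)
# graph -> {"nodes": {"noise": {"type": "noise"}}, "edges": []}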

View File

@ -25,13 +25,10 @@ from enum import Enum
from typing import Literal, Optional, Type, Union from typing import Literal, Optional, Type, Union
import torch import torch
from diffusers.models.modeling_utils import ModelMixin from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from ..raw_model import RawModel from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models # ModelMixin is the base class for all diffusers and transformers models
@ -59,8 +56,8 @@ class ModelType(str, Enum):
ONNX = "onnx" ONNX = "onnx"
Main = "main" Main = "main"
VAE = "vae" Vae = "vae"
LoRA = "lora" Lora = "lora"
ControlNet = "controlnet" # used by model_probe ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding" TextualInversion = "embedding"
IPAdapter = "ip_adapter" IPAdapter = "ip_adapter"
@ -76,9 +73,9 @@ class SubModelType(str, Enum):
TextEncoder2 = "text_encoder_2" TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer" Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2" Tokenizer2 = "tokenizer_2"
VAE = "vae" Vae = "vae"
VAEDecoder = "vae_decoder" VaeDecoder = "vae_decoder"
VAEEncoder = "vae_encoder" VaeEncoder = "vae_encoder"
Scheduler = "scheduler" Scheduler = "scheduler"
SafetyChecker = "safety_checker" SafetyChecker = "safety_checker"
@ -96,8 +93,8 @@ class ModelFormat(str, Enum):
Diffusers = "diffusers" Diffusers = "diffusers"
Checkpoint = "checkpoint" Checkpoint = "checkpoint"
LyCORIS = "lycoris" Lycoris = "lycoris"
ONNX = "onnx" Onnx = "onnx"
Olive = "olive" Olive = "olive"
EmbeddingFile = "embedding_file" EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder" EmbeddingFolder = "embedding_folder"
@ -115,186 +112,127 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum): class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format.""" """Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier - empty str DEFAULT = "" # model files without "fp16" or other qualifier - empty str
FP16 = "fp16" FP16 = "fp16"
FP32 = "fp32" FP32 = "fp32"
ONNX = "onnx" ONNX = "onnx"
OpenVINO = "openvino" OPENVINO = "openvino"
Flax = "flax" FLAX = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
CivitAI = "civitai"
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
"""Base class for model configuration information.""" """Base class for model configuration information."""
key: str = Field(description="A unique key for this model.", default_factory=uuid_string) path: str = Field(description="filesystem path to the model file or directory")
hash: str = Field(description="The hash of the model file(s).") name: str = Field(description="model name")
path: str = Field( base: BaseModelType = Field(description="base model")
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory." type: ModelType = Field(description="type of the model")
) format: ModelFormat = Field(description="model format")
name: str = Field(description="Name of the model.") key: str = Field(description="unique key for model", default="<NOKEY>")
base: BaseModelType = Field(description="The base model.") original_hash: Optional[str] = Field(
description: Optional[str] = Field(description="Model description", default=None) description="original fasthash of model contents", default=None
source: str = Field(description="The original source of the model (path, URL or repo_id).") ) # this is assigned at install time and will not change
source_type: ModelSourceType = Field(description="The type of source") current_hash: Optional[str] = Field(
source_api_response: Optional[str] = Field( description="current fasthash of model contents", default=None
description="The original API response from the source, as stringified JSON.", default=None ) # if model is converted or otherwise modified, this will hold updated hash
) description: Optional[str] = Field(description="human readable description of the model", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
default_settings: Optional[ModelDefaultSettings] = Field( last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
description="Default settings for this model", default=None
)
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(["key", "type", "format"]) schema["required"].extend(
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
)
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra) model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
json_schema_extra=json_schema_extra,
)
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class CheckpointConfigBase(ModelConfigBase): class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models.""" """Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config_path: str = Field(description="path to the checkpoint model config file") config: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
)
class DiffusersConfigBase(ModelConfigBase): class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models.""" """Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
class LoRALyCORISConfig(ModelConfigBase): class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models.""" """Model config for LoRA/Lycoris models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(ModelConfigBase): class VaeCheckpointConfig(ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models.""" """Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
class VaeDiffusersConfig(ModelConfigBase):
class VAEDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version).""" """Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.VAE] = ModelType.VAE type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(_DiffusersConfig):
class ControlNetDiffusersConfig(DiffusersConfigBase):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(_CheckpointConfig):
class ControlNetCheckpointConfig(CheckpointConfigBase):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
class TextualInversionConfig(ModelConfigBase):
class TextualInversionFileConfig(ModelConfigBase):
"""Model config for textual inversion embeddings.""" """Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
class TextualInversionFolderConfig(ModelConfigBase): class _MainConfig(ModelConfigBase):
"""Model config for textual inversion embeddings.""" """Model config for main models."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion vae: Optional[str] = Field(default=None)
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainCheckpointConfig(CheckpointConfigBase):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False upcast_attention: bool = False
ztsnr_training: bool = False
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainDiffusersConfig(DiffusersConfigBase): class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main diffusers models.""" """Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag: class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}") """Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
class IPAdapterConfig(ModelConfigBase): class IPAdapterConfig(ModelConfigBase):
@ -304,10 +242,6 @@ class IPAdapterConfig(ModelConfigBase):
image_encoder_model_id: str image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class CLIPVisionDiffusersConfig(ModelConfigBase): class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision.""" """Model config for ClipVision."""
@ -315,65 +249,58 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase):
"""Model config for T2I.""" """Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
def get_model_discriminator_value(v: Any) -> str: AnyModelConfig = Union[
""" _MainModelConfig,
Computes the discriminator value for a model config. _VaeConfig,
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator _ControlNetConfig,
""" # ModelConfigBase,
format_ = None LoRAConfig,
type_ = None TextualInversionConfig,
if isinstance(v, dict): IPAdapterConfig,
format_ = v.get("format") CLIPVisionDiffusersConfig,
if isinstance(format_, Enum): T2IConfig,
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
else:
format_ = v.format.value
type_ = v.type.value
v = f"{type_}.{format_}"
return v
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
] ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
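The callable-discriminator union can be exercised without any InvokeAI imports. The sketch below uses invented stand-in classes (`VAECheckpoint`, `VAEDiffusers`, `AnyVAE`) to show how a `"<type>.<format>"` tag routes a raw dict to the right config class through `TypeAdapter`; the real `get_model_discriminator_value` additionally unwraps enum values, which is omitted here for brevity.

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class VAECheckpoint(BaseModel):
    type: Literal["vae"] = "vae"
    format: Literal["checkpoint"] = "checkpoint"
    path: str


class VAEDiffusers(BaseModel):
    type: Literal["vae"] = "vae"
    format: Literal["diffusers"] = "diffusers"
    path: str


def get_discriminator(v: Any) -> str:
    # Mirrors the idea of get_model_discriminator_value: build "<type>.<format>"
    # from either a raw dict or a model instance.
    if isinstance(v, dict):
        return f"{v.get('type')}.{v.get('format')}"
    return f"{v.type}.{v.format}"


AnyVAE = Annotated[
    Union[
        Annotated[VAECheckpoint, Tag("vae.checkpoint")],
        Annotated[VAEDiffusers, Tag("vae.diffusers")],
    ],
    Discriminator(get_discriminator),
]

validator = TypeAdapter(AnyVAE)
parsed = validator.validate_python({"type": "vae", "format": "diffusers", "path": "/models/vae"})
assert isinstance(parsed, VAEDiffusers)
```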
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects.""" """Class for parsing config dicts into StableDiffusion Config obects."""
@ -405,6 +332,6 @@ class ModelConfigFactory(object):
assert model is not None assert model is not None
if key: if key:
model.key = key model.key = key
if isinstance(model, CheckpointConfigBase) and timestamp is not None: if timestamp:
model.converted_at = timestamp model.last_modified = timestamp
return model # type: ignore return model # type: ignore

View File

@ -11,175 +11,56 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Optional, Union from typing import Dict, Union
from blake3 import blake3 from imohash import hashfile
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[
"md5",
"sha1",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
class ModelHash: class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
""" """
Creates a hash of a model using a specified algorithm. Return hexdigest string for model located at model_location.
Args: :param model_location: Path to the model
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/")
```
""" """
model_location = Path(model_location)
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None: if model_location.is_file():
if algorithm == "blake3": return cls._hash_file(model_location)
self._hash_file = self._blake3 elif model_location.is_dir():
elif algorithm in hashlib.algorithms_available: return cls._hash_dir(model_location)
self._hash_file = self._get_hashlib(algorithm)
else: else:
raise ValueError(f"Algorithm {algorithm} not available") raise OSError(f"Not a valid file or directory: {model_location}")
self._file_filter = file_filter or self._default_file_filter @classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
def hash(self, model_path: Union[str, Path]) -> str:
""" """
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation. Fasthash a single file and return its hexdigest.
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the :param model_location: Path to the model file
directory. The final composite hash is always computed using BLAKE3.
Args:
model_path: Path to the model
Returns:
str: Hexdigest of the hash of the model
""" """
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
model_path = Path(model_path) @classmethod
if model_path.is_file(): def _hash_dir(cls, model_location: Union[str, Path]) -> str:
return self._hash_file(model_path) components: Dict[str, str] = {}
elif model_path.is_dir():
return self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_path}")
def _hash_dir(self, dir: Path) -> str: for root, _dirs, files in os.walk(model_location):
"""Compute the hash for all files in a directory and return a hexdigest. for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
Args: # hash all the model hashes together, using alphabetic file order
dir: Path to the directory md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
Returns: md5.update(fast_hash.encode("utf-8"))
str: Hexdigest of the hash of the directory return md5.hexdigest()
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
composite_hasher = blake3()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@staticmethod
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
"""Return a list of all model files in the directory.
Args:
model_path: Path to the model
file_filter: Function that takes a file name and returns True if the file should be included in the list.
Returns:
List of all model files in the directory
"""
files: list[Path] = []
for root, _dirs, _files in os.walk(model_path):
for file in _files:
if file_filter(file):
files.append(Path(root, file))
return files
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
algorithm: Hashing algorithm to use
Returns:
A function that hashes a file using the given algorithm
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
return hashlib_hasher
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
Args:
file_path: Path to the file
Returns:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)
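For readers without the `blake3` package, the directory-hashing strategy described above (hash each weights file, then hash the per-file digests in sorted-path order) can be approximated with `hashlib` alone. This is a simplified stand-in rather than the ModelHash implementation; it substitutes SHA-256 for BLAKE3.

```python
import hashlib
import os
from pathlib import Path

WEIGHTS_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")


def hash_file(path: Path) -> str:
    """Stream one file through SHA-256 (stand-in for BLAKE3) without reading it all into memory."""
    hasher = hashlib.sha256()
    buffer = bytearray(128 * 1024)
    mv = memoryview(buffer)
    with open(path, "rb", buffering=0) as f:
        while n := f.readinto(mv):
            hasher.update(mv[:n])
    return hasher.hexdigest()


def hash_model_dir(model_dir: Path) -> str:
    """Composite hash: hash each weights file, then hash the digests in sorted-path order."""
    paths = sorted(
        Path(root, name)
        for root, _dirs, files in os.walk(model_dir)
        for name in files
        if name.endswith(WEIGHTS_EXTENSIONS)
    )
    composite = hashlib.sha256()
    for p in paths:
        composite.update(hash_file(p).encode("utf-8"))
    return composite.hexdigest()
```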

View File

@ -13,7 +13,6 @@ from invokeai.backend.model_manager import (
ModelRepoVariant, ModelRepoVariant,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -51,7 +50,7 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: a ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type is ModelType.Main and not submodel_type: if model_config.type == "main" and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model") raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type) model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -81,7 +80,7 @@ class ModelLoader(ModelLoaderBase):
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path) return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
return False return False
def _load_if_needed( def _load_if_needed(
@ -120,7 +119,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs( return calc_model_size_by_fs(
model_path=model_path, model_path=model_path,
subfolder=submodel_type.value if submodel_type else None, subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None, variant=config.repo_variant if hasattr(config, "repo_variant") else None,
) )
# This needs to be implemented in subclasses that handle checkpoints # This needs to be implemented in subclasses that handle checkpoints

View File

@ -15,8 +15,10 @@ Use like this:
""" """
import hashlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type
from ..config import ( from ..config import (
AnyModelConfig, AnyModelConfig,
@ -25,6 +27,8 @@ from ..config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType, SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
) )
from . import ModelLoaderBase from . import ModelLoaderBase
@ -57,9 +61,6 @@ class ModelLoaderRegistryBase(ABC):
""" """
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry: class ModelLoaderRegistry:
""" """
This class allows model loaders to register their type, base and format. This class allows model loaders to register their type, base and format.
@ -70,10 +71,10 @@ class ModelLoaderRegistry:
@classmethod @classmethod
def register( def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]: ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader.""" """Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]: def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
key = cls._to_registry_key(base, type, format) key = cls._to_registry_key(base, type, format)
if key in cls._registry: if key in cls._registry:
raise Exception( raise Exception(
@ -89,15 +90,33 @@ class ModelLoaderRegistry:
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type.""" """Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2) implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation: if not implementation:
raise NotImplementedError( raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
) )
return implementation, config, submodel_type return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@staticmethod @staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
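The registry shown here pairs a type-preserving decorator with a two-step lookup (an exact base first, then the `Any` wildcard). A self-contained sketch of that pattern follows, with invented names (`LoaderRegistry`, `LoaderBase`, plain string keys instead of the real enums):

```python
from typing import Callable, Dict, Type, TypeVar


class LoaderBase:
    """Stand-in for ModelLoaderBase."""


TLoader = TypeVar("TLoader", bound=LoaderBase)


class LoaderRegistry:
    _registry: Dict[str, Type[LoaderBase]] = {}

    @classmethod
    def register(cls, base: str, type_: str, format_: str) -> Callable[[Type[TLoader]], Type[TLoader]]:
        """Decorator that records a loader class under a base/type/format key."""

        def decorator(subclass: Type[TLoader]) -> Type[TLoader]:
            key = f"{base}/{type_}/{format_}"
            if key in cls._registry:
                raise Exception(f"{key} already registered by {cls._registry[key].__name__}")
            cls._registry[key] = subclass
            return subclass

        return decorator

    @classmethod
    def get_implementation(cls, base: str, type_: str, format_: str) -> Type[LoaderBase]:
        # Prefer a base-specific loader, then fall back to the "any" wildcard registration.
        impl = cls._registry.get(f"{base}/{type_}/{format_}") or cls._registry.get(f"any/{type_}/{format_}")
        if impl is None:
            raise NotImplementedError(f"No loader registered for base={base}, type={type_}, format={format_}")
        return impl


@LoaderRegistry.register(base="any", type_="vae", format_="checkpoint")
class VaeCheckpointLoader(LoaderBase):
    pass


assert LoaderRegistry.get_implementation("sd-1", "vae", "checkpoint") is VaeCheckpointLoader
```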

View File

@ -3,8 +3,8 @@
from pathlib import Path from pathlib import Path
import safetensors
import torch import torch
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
@ -12,7 +12,6 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
@ -21,15 +20,15 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlNetLoader(GenericDiffusersLoader): class ControlnetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models.""" """Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase): if config.format != ModelFormat.Checkpoint:
return False return False
elif ( elif (
dest_path.exists() dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0) and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
): ):
return False return False
@ -38,13 +37,13 @@ class ControlNetLoader(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"ControlNet conversion not supported for model type: {config.base}") raise Exception(f"Vae conversion not supported for model type: {config.base}")
else: else:
assert isinstance(config, CheckpointConfigBase) assert hasattr(config, "config")
config_file = config.config_path config_file = config.config
if model_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu") checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else: else:
checkpoint = torch.load(model_path, map_location="cpu") checkpoint = torch.load(model_path, map_location="cpu")
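Both converter paths branch on the file suffix before reading the checkpoint. A small hedged helper capturing that pattern is sketched below; the `state_dict` unwrapping mirrors the VAE loader further down, and nothing here is the actual InvokeAI API.

```python
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors.torch import load_file as safetensors_load_file


def load_checkpoint(model_path: Path) -> Dict[str, Any]:
    """Load tensors from a .safetensors file or a pickled torch checkpoint."""
    if model_path.suffix == ".safetensors":
        return safetensors_load_file(model_path, device="cpu")
    checkpoint = torch.load(model_path, map_location="cpu")
    # Some checkpoints nest their weights under a "state_dict" key.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return checkpoint
```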

View File

@ -3,10 +3,9 @@
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Dict, Optional
from diffusers.configuration_utils import ConfigMixin from diffusers import ConfigMixin, ModelMixin
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
@ -42,7 +41,6 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling # TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load.""" """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
result = None
if submodel_type: if submodel_type:
try: try:
config = self._load_diffusers_config(model_path, config_name="model_index.json") config = self._load_diffusers_config(model_path, config_name="model_index.json")
@ -66,7 +64,6 @@ class GenericDiffusersLoader(ModelLoader):
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e: except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
assert result is not None
return result return result
# TO DO: Add exception handling # TO DO: Add exception handling
@ -78,7 +75,7 @@ class GenericDiffusersLoader(ModelLoader):
result: ModelMixin = getattr(res_type, class_name) result: ModelMixin = getattr(res_type, class_name)
return result return result
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]: def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name) return ConfigLoader.load_config(model_path, config_name=config_name)
@ -86,8 +83,8 @@ class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files.""" """Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod @classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride] def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Load a diffusrs ConfigMixin configuration.""" """Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name") cls.config_name = kwargs.pop("config_name")
# TODO(psyche): the types on this diffusers method are not correct # Diffusers doesn't provide typing info
return super().load_config(*args, **kwargs) # type: ignore return super().load_config(*args, **kwargs) # type: ignore
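This loader resolves the diffusers class to instantiate from the `_class_name` entry of the model's `config.json`. A rough sketch of that lookup, assuming `diffusers` is installed and the config file is present (this is not the GenericDiffusersLoader code itself):

```python
import json
from pathlib import Path


def resolve_diffusers_class(model_path: Path) -> type:
    """Read config.json and return the diffusers class named by its "_class_name" entry."""
    import diffusers  # assumed available, as in the loaders above

    config = json.loads((model_path / "config.json").read_text())
    class_name = config["_class_name"]
    return getattr(diffusers, class_name)
```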

View File

@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
if submodel_type is not None: if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.") raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter( model = build_ip_adapter(
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"), ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
device=torch.device("cpu"), device=torch.device("cpu"),
dtype=self._torch_dtype, dtype=self._torch_dtype,
) )

View File

@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from .. import ModelLoader, ModelLoaderRegistry from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader): class LoraLoader(ModelLoader):
"""Class to load LoRA models.""" """Class to load LoRA models."""

View File

@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(GenericDiffusersLoader): class OnnyxDiffusersModel(GenericDiffusersLoader):
"""Class to load onnx models.""" """Class to load onnx models."""

View File

@ -4,8 +4,7 @@
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
@ -17,7 +16,7 @@ from invokeai.backend.model_manager import (
ModelVariantType, ModelVariantType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
@ -55,11 +54,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase): if config.format != ModelFormat.Checkpoint:
return False return False
elif ( elif (
dest_path.exists() dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0) and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
): ):
return False return False
@ -74,7 +73,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
) )
config_file = config.config_path config_file = config.config
self._logger.info(f"Converting {model_path} to diffusers format") self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(

View File

@ -3,9 +3,9 @@
from pathlib import Path from pathlib import Path
import safetensors
import torch import torch
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
@ -13,25 +13,24 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader): class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models.""" """Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase): if config.format != ModelFormat.Checkpoint:
return False return False
elif ( elif (
dest_path.exists() dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0) and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
): ):
return False return False
@ -39,15 +38,16 @@ class VaeLoader(GenericDiffusersLoader):
return True return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TODO(MM2): check whether sdxl VAE models convert. # TO DO: check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}") raise Exception(f"Vae conversion not supported for model type: {config.base}")
else: else:
assert isinstance(config, CheckpointConfigBase) config_file = (
config_file = config.config_path "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
if model_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu") checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else: else:
checkpoint = torch.load(model_path, map_location="cpu") checkpoint = torch.load(model_path, map_location="cpu")
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file) ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
assert isinstance(ckpt_config, DictConfig) assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers( vae_model = convert_ldm_vae_to_diffusers(

View File

@ -16,7 +16,6 @@ from diffusers import AutoPipelineForText2Image
from diffusers.utils import logging as dlogging from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from . import ( from . import (
@ -118,6 +117,7 @@ class ModelMerger(object):
config = self._installer.app_config config = self._installer.app_config
store = self._installer.record_store store = self._installer.record_store
base_models: Set[BaseModelType] = set() base_models: Set[BaseModelType] = set()
vae = None
variant = None if self._installer.app_config.full_precision else "fp16" variant = None if self._installer.app_config.full_precision else "fp16"
assert ( assert (
@ -134,6 +134,10 @@ class ModelMerger(object):
"normal" "normal"
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
# pick up the first model's vae
if key == model_keys[0]:
vae = info.vae
# tally base models used # tally base models used
base_models.add(info.base) base_models.add(info.base)
model_paths.extend([config.models_path / info.path]) model_paths.extend([config.models_path / info.path])
@ -159,10 +163,12 @@ class ModelMerger(object):
# update model's config # update model's config
model_config = self._installer.record_store.get_model(key) model_config = self._installer.record_store.get_model(key)
model_config.name = merged_model_name model_config.update(
model_config.description = f"Merge of models {', '.join(model_names)}" {
"name": merged_model_name,
self._installer.record_store.update_model( "description": f"Merge of models {', '.join(model_names)}",
key, ModelRecordChanges(name=model_config.name, description=model_config.description) "vae": vae,
}
) )
self._installer.record_store.update_model(key, model_config)
return model_config return model_config

View File

@ -25,7 +25,9 @@ from .metadata_base import (
AnyModelRepoMetadataValidator, AnyModelRepoMetadataValidator,
BaseMetadata, BaseMetadata,
CivitaiMetadata, CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata, HuggingFaceMetadata,
LicenseRestrictions,
ModelMetadataWithFiles, ModelMetadataWithFiles,
RemoteModelFile, RemoteModelFile,
UnknownMetadataException, UnknownMetadataException,
@ -36,8 +38,10 @@ __all__ = [
"AnyModelRepoMetadataValidator", "AnyModelRepoMetadataValidator",
"CivitaiMetadata", "CivitaiMetadata",
"CivitaiMetadataFetch", "CivitaiMetadataFetch",
"CommercialUsage",
"HuggingFaceMetadata", "HuggingFaceMetadata",
"HuggingFaceMetadataFetch", "HuggingFaceMetadataFetch",
"LicenseRestrictions",
"ModelMetadataFetchBase", "ModelMetadataFetchBase",
"BaseMetadata", "BaseMetadata",
"ModelMetadataWithFiles", "ModelMetadataWithFiles",

View File

@ -23,21 +23,22 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words) print(metadata.trained_words)
""" """
import json
import re import re
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Dict, Optional
import requests import requests
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests.sessions import Session from requests.sessions import Session
from invokeai.backend.model_manager.config import ModelRepoVariant from invokeai.backend.model_manager import ModelRepoVariant
from ..metadata_base import ( from ..metadata_base import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
CivitaiMetadata, CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
RemoteModelFile, RemoteModelFile,
UnknownMetadataException, UnknownMetadataException,
) )
@ -51,13 +52,10 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/" CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase): class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai.""" """Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None): def __init__(self, session: Optional[Session] = None):
""" """
Initialize the fetcher with an optional requests.sessions.Session object. Initialize the fetcher with an optional requests.sessions.Session object.
@ -65,7 +63,6 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection. this module without an internet connection.
""" """
self._requests = session or requests.Session() self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
""" """
@ -105,21 +102,22 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownMetadataException`. May raise an `UnknownMetadataException`.
""" """
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id) model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json() model_json = self._requests.get(model_url).json()
return self._from_api_response(model_json) return self._from_model_json(model_json)
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata: def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
try: try:
version_id = version_id or api_response["modelVersions"][0]["id"] version_id = version_id or model_json["modelVersions"][0]["id"]
except TypeError as excp: except TypeError as excp:
raise UnknownMetadataException from excp raise UnknownMetadataException from excp
# loop till we find the section containing the version requested # loop till we find the section containing the version requested
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id] version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
if not version_sections: if not version_sections:
raise UnknownMetadataException(f"Version {version_id} not found in model metadata") raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0] version_json = version_sections[0]
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary. # Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
primary = [x for x in version_json["files"] if x.get("primary")] primary = [x for x in version_json["files"] if x.get("primary")]
@ -136,23 +134,36 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
url = url + f"?type={primary_file['type']}{metadata_string}" url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [ model_files = [
RemoteModelFile( RemoteModelFile(
url=self._get_url_with_api_key(url), url=url,
path=Path(primary_file["name"]), path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024), size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"], sha256=primary_file["hashes"]["SHA256"],
) )
] ]
try:
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_phrases: set[str] = set()
return CivitaiMetadata( return CivitaiMetadata(
id=model_json["id"],
name=version_json["name"], name=version_json["name"],
version_id=version_json["id"],
version_name=version_json["name"],
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
files=model_files, files=model_files,
trigger_phrases=trigger_phrases, download_url=version_json["downloadUrl"],
api_response=json.dumps(version_json), thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model_json["creator"]["username"],
description=model_json["description"],
version_description=version_json["description"] or "",
tags=model_json["tags"],
trained_words=version_json["trainedWords"],
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
) )
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata: def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
@ -163,14 +174,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
""" """
if model_id is None: if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id) version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(self._get_url_with_api_key(version_url)).json() version = self._requests.get(version_url).json()
if error := version.get("error"): if error := version.get("error"):
raise UnknownMetadataException(error) raise UnknownMetadataException(error)
model_id = version["modelId"] model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id) model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json() model_json = self._requests.get(model_url).json()
return self._from_api_response(model_json, version_id) return self._from_model_json(model_json, version_id)
@classmethod @classmethod
def from_json(cls, json: str) -> CivitaiMetadata: def from_json(cls, json: str) -> CivitaiMetadata:
@ -178,11 +189,6 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
metadata = CivitaiMetadata.model_validate_json(json) metadata = CivitaiMetadata.model_validate_json(json)
return metadata return metadata
def _get_url_with_api_key(self, url: str) -> str:
if not self._api_key:
return url
if "?" in url: def _fix_timezone(date: str) -> str:
return f"{url}&token={self._api_key}" return re.sub(r"Z$", "+00:00", date)
return f"{url}?token={self._api_key}"

View File

@ -13,7 +13,6 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags) print(metadata.tags)
""" """
import json
import re import re
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -24,7 +23,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests.sessions import Session from requests.sessions import Session
from invokeai.backend.model_manager.config import ModelRepoVariant from invokeai.backend.model_manager import ModelRepoVariant
from ..metadata_base import ( from ..metadata_base import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
@ -61,7 +60,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
# Little loop which tries fetching a revision corresponding to the selected variant. # Little loop which tries fetching a revision corresponding to the selected variant.
# If not available, then set variant to None and get the default. # If not available, then set variant to None and get the default.
# If this too fails, raise exception. # If this too fails, raise exception.
model_info = None model_info = None
while not model_info: while not model_info:
try: try:
@ -74,24 +72,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
else: else:
variant = None variant = None
files: list[RemoteModelFile] = []
_, name = id.split("/") _, name = id.split("/")
for s in model_info.siblings or []:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
)
return HuggingFaceMetadata( return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str) id=model_info.id,
author=model_info.author,
name=name,
last_modified=model_info.last_modified,
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
tags=model_info.tags,
files=[
RemoteModelFile(
url=hf_hub_url(id, x.rfilename, revision=variant),
path=Path(name, x.rfilename),
size=x.size,
sha256=x.lfs.get("sha256") if x.lfs else None,
)
for x in model_info.siblings
],
) )
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:

View File

@ -14,8 +14,10 @@ versions of these fields are intended to be kept in sync with the
remote repo. remote repo.
""" """
from datetime import datetime
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
from huggingface_hub import configure_http_backend, hf_hub_url from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, Field, TypeAdapter
@ -32,6 +34,31 @@ class UnknownMetadataException(Exception):
"""Raised when no metadata is available for a model.""" """Raised when no metadata is available for a model."""
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
description="Type of commercial use allowed if no commercial use is allowed.", default=None
)
class RemoteModelFile(BaseModel): class RemoteModelFile(BaseModel):
"""Information about a downloadable file that forms part of a model.""" """Information about a downloadable file that forms part of a model."""
@ -45,6 +72,8 @@ class ModelMetadataBase(BaseModel):
"""Base class for model metadata information.""" """Base class for model metadata information."""
name: str = Field(description="model's name") name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class BaseMetadata(ModelMetadataBase): class BaseMetadata(ModelMetadataBase):
@ -82,16 +111,60 @@ class CivitaiMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by Civitai.""" """Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai" type: Literal["civitai"] = "civitai"
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response") id: int = Field(description="Civitai version identifier")
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None) version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_minmax: Tuple[float, float] = Field(
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
) # note: For future use
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
if self.restrictions.AllowCommercialUse is None:
return False
else:
# accommodate schema change
acu = self.restrictions.AllowCommercialUse
commercial_usage = acu if isinstance(acu, set) else {acu}
return CommercialUsage.No not in commercial_usage
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
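The `allow_commercial_use` property accommodates both shapes Civitai has used for `allowCommercialUse` (a single value or a set). Below is a small standalone mirror of that check with no pydantic dependency; the function name is invented for illustration.

```python
from enum import Enum
from typing import Optional, Set, Union


class CommercialUsage(str, Enum):
    No = "None"
    Image = "Image"
    Rent = "Rent"
    RentCivit = "RentCivit"
    Sell = "Sell"


def allows_commercial_use(value: Optional[Union[CommercialUsage, Set[CommercialUsage]]]) -> bool:
    """Accept either the older single value or the newer set and report whether commercial use is allowed."""
    if value is None:
        return False
    usages = value if isinstance(value, set) else {value}
    return CommercialUsage.No not in usages


assert allows_commercial_use({CommercialUsage.Image, CommercialUsage.Sell}) is True
assert allows_commercial_use(CommercialUsage.No) is False
```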
class HuggingFaceMetadata(ModelMetadataWithFiles): class HuggingFaceMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by HuggingFace.""" """Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface" type: Literal["huggingface"] = "huggingface"
id: str = Field(description="The HF model id") id: str = Field(description="huggingface model id")
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None) tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
def download_urls( def download_urls(
self, self,
@ -120,7 +193,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
# the next step reads model_index.json to determine which subdirectories belong # the next step reads model_index.json to determine which subdirectories belong
# to the model # to the model
if Path(f"{prefix}model_index.json") in paths: if Path(f"{prefix}model_index.json") in paths:
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None) url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
resp = session.get(url) resp = session.get(url)
resp.raise_for_status() resp.raise_for_status()
submodels = resp.json() submodels = resp.json()
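The `download_urls` logic pairs each repo file with an `hf_hub_url` remote URL and a local path rooted at the model name. A rough sketch under the assumption that the file listing is already known; the repo id and file names below are purely illustrative.

```python
from pathlib import Path

from huggingface_hub import hf_hub_url

repo_id = "stabilityai/sd-turbo"  # hypothetical example repo
filenames = ["model_index.json", "vae/config.json", "vae/diffusion_pytorch_model.safetensors"]

# Remote URL plus local relative path for each file, roughly as RemoteModelFile entries are built.
remote_files = [
    {"url": hf_hub_url(repo_id, filename=name), "path": Path(repo_id.split("/")[1]) / name}
    for name in filenames
]
print(remote_files[0]["url"])
```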

View File

@ -0,0 +1,221 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from a preexisting SqliteDatabase.
:param db: SqliteDatabase object that manages the underlying sqlite3 connection and threading lock
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
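A hedged usage sketch for add_metadata as defined above; db, key and fetched_metadata are placeholders, and the key must already exist in the model_config table:

store = ModelMetadataStore(db)
try:
    store.add_metadata(key, fetched_metadata)
except UnknownMetadataException:
    # raised when the FOREIGN KEY check fails, i.e. no model_config row has this key
    print(f"model {key} is not installed; add its config record first")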
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
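Illustrative only: because search_by_tag intersects the matches for each tag, a model key is returned only if the model carries every requested tag.

keys = store.search_by_tag({"lora", "anime"})   # placeholder tags
for key in keys:
    print(key, store.get_metadata(key))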
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
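The tag helpers above assume a tags table keyed by tag_id and a model_tags join table. A rough schema sketch (not the project's actual migration) for orientation:

TAGS_SCHEMA_SKETCH = """
CREATE TABLE IF NOT EXISTS tags (
    tag_id   INTEGER PRIMARY KEY AUTOINCREMENT,
    tag_text TEXT UNIQUE NOT NULL
);
CREATE TABLE IF NOT EXISTS model_tags (
    model_id TEXT    NOT NULL REFERENCES model_metadata (id) ON DELETE CASCADE,
    tag_id   INTEGER NOT NULL REFERENCES tags (tag_id) ON DELETE CASCADE,
    UNIQUE (model_id, tag_id)
);
"""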


@ -8,7 +8,6 @@ import torch
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.util import SilenceWarnings from invokeai.backend.util.util import SilenceWarnings
from .config import ( from .config import (
@ -18,12 +17,11 @@ from .config import (
ModelConfigFactory, ModelConfigFactory,
ModelFormat, ModelFormat,
ModelRepoVariant, ModelRepoVariant,
ModelSourceType,
ModelType, ModelType,
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
) )
from .hash import ModelHash from .hash import FastModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any] CkptType = Dict[str, Any]
@ -97,8 +95,8 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main, "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main, "StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main, "LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE, "AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.VAE, "AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision, "CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter, "T2IAdapter": ModelType.T2IAdapter,
@ -110,6 +108,14 @@ class ModelProbe(object):
) -> None: ) -> None:
cls.PROBES[format][model_type] = probe_class cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@classmethod @classmethod
def probe( def probe(
cls, cls,
@ -131,21 +137,19 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None model_info = None
model_type = None model_type = None
if format_type is ModelFormat.Diffusers: if format_type == "diffusers":
model_type = cls.get_model_type_from_folder(model_path) model_type = cls.get_model_type_from_folder(model_path)
else: else:
model_type = cls.get_model_type_from_checkpoint(model_path) model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type) probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class: if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
probe = probe_class(model_path) probe = probe_class(model_path)
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["source"] = fields.get("source") or model_path.as_posix()
fields["key"] = fields.get("key", uuid_string())
fields["path"] = model_path.as_posix() fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type() fields["base"] = fields.get("base") or probe.get_base_type()
@ -157,17 +161,15 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
) )
fields["format"] = fields.get("format") or probe.get_format() fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models # additional fields needed for main and controlnet models
if ( if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] fields["config"] = cls._get_checkpoint_config_path(
and fields["format"] is ModelFormat.Checkpoint
):
fields["config_path"] = cls._get_checkpoint_config_path(
model_path, model_path,
model_type=fields["type"], model_type=fields["type"],
base_type=fields["base"], base_type=fields["base"],
@ -177,7 +179,7 @@ class ModelProbe(object):
# additional fields needed for main non-checkpoint models # additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [ elif fields["type"] == ModelType.Main and fields["format"] in [
ModelFormat.ONNX, ModelFormat.Onnx,
ModelFormat.Olive, ModelFormat.Olive,
ModelFormat.Diffusers, ModelFormat.Diffusers,
]: ]:
@ -186,7 +188,7 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction and fields["prediction_type"] == SchedulerPredictionType.VPrediction
) )
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None)) model_info = ModelConfigFactory.make_config(fields, key=fields.get("key", None))
return model_info return model_info
@classmethod @classmethod
@ -211,11 +213,11 @@ class ModelProbe(object):
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.VAE return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.LoRA return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.LoRA return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}: elif key in {"emb_params", "string_to_param"}:
@ -237,7 +239,7 @@ class ModelProbe(object):
if (folder_path / f"learned_embeds.{suffix}").exists(): if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.LoRA return ModelType.Lora
if (folder_path / "unet/model.onnx").exists(): if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists(): if (folder_path / "image_encoder.txt").exists():
@ -283,21 +285,13 @@ class ModelProbe(object):
if possible_conf.exists(): if possible_conf.exists():
return possible_conf.absolute() return possible_conf.absolute()
if model_type is ModelType.Main: if model_type == ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type] config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type] config_file = config_file[prediction_type]
elif model_type is ModelType.ControlNet: elif model_type == ModelType.ControlNet:
config_file = ( config_file = (
"../controlnet/cldm_v15.yaml" "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../controlnet/cldm_v21.yaml"
)
elif model_type is ModelType.VAE:
config_file = (
"../stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../stable-diffusion/v2-inference.yaml"
) )
else: else:
raise InvalidModelConfigException( raise InvalidModelConfigException(
@ -503,12 +497,12 @@ class FolderProbeBase(ProbeBase):
if ".fp16" in x.suffixes: if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16 return ModelRepoVariant.FP16
if "openvino_model" in x.name: if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO return ModelRepoVariant.OPENVINO
if "flax_model" in x.name: if "flax_model" in x.name:
return ModelRepoVariant.Flax return ModelRepoVariant.FLAX
if x.suffix == ".onnx": if x.suffix == ".onnx":
return ModelRepoVariant.ONNX return ModelRepoVariant.ONNX
return ModelRepoVariant.Default return ModelRepoVariant.DEFAULT
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
@ -714,8 +708,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
@ -723,8 +717,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
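A hedged sketch of how the registration hook above can be extended; CheckpointProbeBase is assumed to exist elsewhere in this module, and the probe class here is hypothetical:

class MyCheckpointProbe(CheckpointProbeBase):  # hypothetical probe
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.StableDiffusion1

ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, MyCheckpointProbe)
# config = ModelProbe.probe(Path("/path/to/model.safetensors"))  # would now resolve this combination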


@ -13,7 +13,6 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
""" """
import re import re
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set
@ -35,7 +34,7 @@ def filter_files(
The file list can be obtained from the `files` field of HuggingFaceMetadata, The file list can be obtained from the `files` field of HuggingFaceMetadata,
as defined in `invokeai.backend.model_manager.metadata.metadata_base`. as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
""" """
variant = variant or ModelRepoVariant.Default variant = variant or ModelRepoVariant.DEFAULT
paths: List[Path] = [] paths: List[Path] = []
root = files[0].parts[0] root = files[0].parts[0]
@ -74,81 +73,64 @@ def filter_files(
return sorted(_filter_by_variant(paths, variant)) return sorted(_filter_by_variant(paths, variant))
@dataclass
class SubfolderCandidate:
path: Path
score: int
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]: def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths.""" """Select the proper variant files from a list of HuggingFace repo_id paths."""
result: set[Path] = set() result = set()
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {} basenames: Dict[Path, Path] = {}
for path in files: for path in files:
if path.suffix in [".onnx", ".pb", ".onnx_data"]: if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX: if variant == ModelRepoVariant.ONNX:
result.add(path) result.add(path)
elif "openvino_model" in path.name: elif "openvino_model" in path.name:
if variant == ModelRepoVariant.OpenVINO: if variant == ModelRepoVariant.OPENVINO:
result.add(path) result.add(path)
elif "flax_model" in path.name: elif "flax_model" in path.name:
if variant == ModelRepoVariant.Flax: if variant == ModelRepoVariant.FLAX:
result.add(path) result.add(path)
elif path.suffix in [".json", ".txt"]: elif path.suffix in [".json", ".txt"]:
result.add(path) result.add(path)
elif variant in [ elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
ModelRepoVariant.FP16, ModelRepoVariant.FP16,
ModelRepoVariant.FP32, ModelRepoVariant.FP32,
ModelRepoVariant.Default, ModelRepoVariant.DEFAULT,
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]: ]:
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
# text encoders:
#
# - text_encoder/model.fp16.safetensors
# - text_encoder/model.safetensors
# - text_encoder/pytorch_model.bin
# - text_encoder/pytorch_model.fp16.bin
#
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
# variant and format and select the best one.
parent = path.parent parent = path.parent
score = 0 suffixes = path.suffixes
if len(suffixes) == 2:
variant_label, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
variant_label = ""
suffix = suffixes[0]
basename = parent / path.stem
if path.suffix == ".safetensors": if previous := basenames.get(basename):
score += 1 if (
previous.suffix != ".safetensors" and suffix == ".safetensors"
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None ): # replace non-safetensors with safetensors when available
basenames[basename] = path
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant if variant_label == f".{variant}":
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT. basenames[basename] = path
if candidate_variant_label == f".{variant}" or ( elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default] basenames[basename] = path
): else:
score += 1 basenames[basename] = path
if parent not in subfolder_weights:
subfolder_weights[parent] = []
subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
else: else:
continue continue
for candidate_list in subfolder_weights.values(): for v in basenames.values():
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score) result.add(v)
if highest_score_candidate:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than # If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list # config and text files then we return an empty list
if ( if (
variant variant
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax] and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
and not any(variant.value in x.name for x in result) and not any(variant.value in x.name for x in result)
): ):
return set() return set()
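A hedged usage sketch for the selection logic above; the literal paths stand in for HuggingFaceMetadata.files:

from pathlib import Path

files = [
    Path("repo/model_index.json"),
    Path("repo/text_encoder/model.safetensors"),
    Path("repo/text_encoder/model.fp16.safetensors"),
    Path("repo/text_encoder/pytorch_model.bin"),
]
selected = filter_files(files, variant=ModelRepoVariant.FP16)
# per the scoring above, model_index.json is kept and model.fp16.safetensors wins the
# text_encoder subfolder (safetensors format plus an exact variant match)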


@ -4,11 +4,13 @@ Initialization file for the invokeai.backend.stable_diffusion package
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401 from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
from .diffusion import InvokeAIDiffuserComponent # noqa: F401 from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .seamless import set_seamless # noqa: F401 from .seamless import set_seamless # noqa: F401
__all__ = [ __all__ = [
"PipelineIntermediateState", "PipelineIntermediateState",
"StableDiffusionGeneratorPipeline", "StableDiffusionGeneratorPipeline",
"InvokeAIDiffuserComponent", "InvokeAIDiffuserComponent",
"AttentionMapSaver",
"set_seamless", "set_seamless",
] ]


@ -12,6 +12,7 @@ import torch
import torchvision.transforms as T import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
@ -25,9 +26,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..util import auto_detect_slice_size, normalize_device from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
@dataclass @dataclass
@ -38,6 +39,7 @@ class PipelineIntermediateState:
timestep: int timestep: int
latents: torch.Tensor latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None predicted_original: Optional[torch.Tensor] = None
attention_map_saver: Optional[AttentionMapSaver] = None
@dataclass @dataclass
@ -188,6 +190,19 @@ class T2IAdapterData:
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
Output class for InvokeAI's Stable Diffusion pipeline.
Args:
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion. Pipeline for text-to-image generation using Stable Diffusion.
@ -328,9 +343,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
masked_latents: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False, gradient_mask: Optional[bool] = False,
seed: Optional[int] = None, seed: Optional[int] = None,
) -> torch.Tensor: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if init_timestep.shape[0] == 0: if init_timestep.shape[0] == 0:
return latents return latents, None
if additional_guidance is None: if additional_guidance is None:
additional_guidance = [] additional_guidance = []
@ -370,7 +385,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask)) additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try: try:
latents = self.generate_latents_from_embeddings( latents, attention_map_saver = self.generate_latents_from_embeddings(
latents, latents,
timesteps, timesteps,
conditioning_data, conditioning_data,
@ -387,7 +402,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if mask is not None and not gradient_mask: if mask is not None and not gradient_mask:
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)) latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
return latents return latents, attention_map_saver
def generate_latents_from_embeddings( def generate_latents_from_embeddings(
self, self,
@ -400,22 +415,23 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
ip_adapter_data: Optional[list[IPAdapterData]] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
) -> torch.Tensor: ):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
if additional_guidance is None: if additional_guidance is None:
additional_guidance = [] additional_guidance = []
batch_size = latents.shape[0] batch_size = latents.shape[0]
attention_map_saver: Optional[AttentionMapSaver] = None
if timesteps.shape[0] == 0: if timesteps.shape[0] == 0:
return latents return latents, attention_map_saver
ip_adapter_unet_patcher = None ip_adapter_unet_patcher = None
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context( attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model, self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info, extra_conditioning_info=conditioning_data.extra,
step_count=len(self.scheduler.timesteps),
) )
self.use_ip_adapter = False self.use_ip_adapter = False
elif ip_adapter_data is not None: elif ip_adapter_data is not None:
@ -466,6 +482,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
predicted_original = getattr(step_output, "pred_original_sample", None) predicted_original = getattr(step_output, "pred_original_sample", None)
# TODO resuscitate attention map saving
# if i == len(timesteps)-1 and extra_conditioning_info is not None:
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
# attention_map_token_ids = range(1, eos_token_index)
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
if callback is not None: if callback is not None:
callback( callback(
PipelineIntermediateState( PipelineIntermediateState(
@ -475,10 +498,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timestep=int(t), timestep=int(t),
latents=latents, latents=latents,
predicted_original=predicted_original, predicted_original=predicted_original,
attention_map_saver=attention_map_saver,
) )
) )
return latents return latents, attention_map_saver
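A minimal sketch of a progress callback compatible with the callback parameter shown above; only fields visible in this diff (timestep, latents) are used:

def on_step(state: PipelineIntermediateState) -> None:
    print(f"t={state.timestep}, latents={tuple(state.latents.shape)}")

# passed as: generate_latents_from_embeddings(..., callback=on_step)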
@torch.inference_mode() @torch.inference_mode()
def step( def step(
@ -520,9 +544,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect. # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
ip_adapter_unet_patcher.set_scale(i, 0.0) ip_adapter_unet_patcher.set_scale(i, 0.0)
# Handle ControlNet(s) # Handle ControlNet(s) and T2I-Adapter(s)
down_block_additional_residuals = None down_block_additional_residuals = None
mid_block_additional_residual = None mid_block_additional_residual = None
down_intrablock_additional_residuals = None
# if control_data is not None and t2i_adapter_data is not None:
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
# raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
# elif control_data is not None:
if control_data is not None: if control_data is not None:
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step( down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data, control_data=control_data,
@ -532,9 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
) )
# elif t2i_adapter_data is not None:
# Handle T2I-Adapter(s)
down_intrablock_additional_residuals = None
if t2i_adapter_data is not None: if t2i_adapter_data is not None:
accum_adapter_state = None accum_adapter_state = None
for single_t2i_adapter_data in t2i_adapter_data: for single_t2i_adapter_data in t2i_adapter_data:
@ -560,6 +588,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
for idx, value in enumerate(single_t2i_adapter_data.adapter_state): for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
accum_adapter_state[idx] += value * t2i_adapter_weight accum_adapter_state[idx] += value * t2i_adapter_weight
# down_block_additional_residuals = accum_adapter_state
down_intrablock_additional_residuals = accum_adapter_state down_intrablock_additional_residuals = accum_adapter_state
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
@ -568,6 +597,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter


@ -2,4 +2,6 @@
Initialization file for invokeai.models.diffusion Initialization file for invokeai.models.diffusion
""" """
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401 from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401


@ -21,7 +21,11 @@ class ExtraConditioningInfo:
@dataclass @dataclass
class BasicConditioningInfo: class BasicConditioningInfo:
embeds: torch.Tensor embeds: torch.Tensor
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
# should only be stored in one place.
extra_conditioning: Optional[ExtraConditioningInfo] extra_conditioning: Optional[ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None): def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype) self.embeds = self.embeds.to(device=device, dtype=dtype)
@ -79,6 +83,7 @@ class ConditioningData:
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
""" """
guidance_rescale_multiplier: float = 0 guidance_rescale_multiplier: float = 0
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict) scheduler_args: dict[str, Any] = field(default_factory=dict)
""" """
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing(). Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
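For context, a minimal sketch (not InvokeAI's implementation) of the rescaling described in the paper referenced above: the CFG output is renormalized toward the standard deviation of the text-conditioned prediction and blended back by guidance_rescale_multiplier (phi).

import torch

def rescale_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, phi: float) -> torch.Tensor:
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return phi * rescaled + (1.0 - phi) * noise_cfg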


@ -3,13 +3,19 @@
import enum import enum
import math
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Callable, Optional
import diffusers
import psutil
import torch import torch
from compel.cross_attention_control import Arguments from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype from ...util import torch_dtype
@ -19,14 +25,72 @@ class CrossAttentionType(enum.Enum):
TOKENS = 2 TOKENS = 2
class CrossAttnControlContext: class Context:
def __init__(self, arguments: Arguments): cross_attention_mask: Optional[torch.Tensor]
cross_attention_index_map: Optional[torch.Tensor]
class Action(enum.Enum):
NONE = 0
SAVE = (1,)
APPLY = 2
def __init__(self, arguments: Arguments, step_count: int):
""" """
:param arguments: Arguments for the cross-attention control process :param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
""" """
self.cross_attention_mask: Optional[torch.Tensor] = None self.cross_attention_mask = None
self.cross_attention_index_map: Optional[torch.Tensor] = None self.cross_attention_index_map = None
self.self_cross_attention_action = Context.Action.NONE
self.tokens_cross_attention_action = Context.Action.NONE
self.arguments = arguments self.arguments = arguments
self.step_count = step_count
self.self_cross_attention_module_identifiers = []
self.tokens_cross_attention_module_identifiers = []
self.saved_cross_attention_maps = {}
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
raise AssertionError(f"name {name} cannot appear more than once")
self.self_cross_attention_module_identifiers.append(name)
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
raise AssertionError(f"name {name} cannot appear more than once")
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.SAVE
else:
self.tokens_cross_attention_action = Context.Action.SAVE
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.APPLY
else:
self.tokens_cross_attention_action = Context.Action.APPLY
def is_tokens_cross_attention(self, module_identifier) -> bool:
return module_identifier in self.tokens_cross_attention_module_identifiers
def get_should_save_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == Context.Action.SAVE
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == Context.Action.SAVE
return False
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == Context.Action.APPLY
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == Context.Action.APPLY
return False
def get_active_cross_attention_control_types_for_step( def get_active_cross_attention_control_types_for_step(
self, percent_through: float = None self, percent_through: float = None
@ -47,8 +111,219 @@ class CrossAttnControlContext:
to_control.append(CrossAttentionType.TOKENS) to_control.append(CrossAttentionType.TOKENS)
return to_control return to_control
def save_slice(
self,
identifier: str,
slice: torch.Tensor,
dim: Optional[int],
offset: int,
slice_size: Optional[int],
):
if identifier not in self.saved_cross_attention_maps:
self.saved_cross_attention_maps[identifier] = {
"dim": dim,
"slice_size": slice_size,
"slices": {offset or 0: slice},
}
else:
self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext): def get_slice(
self,
identifier: str,
requested_dim: Optional[int],
requested_offset: int,
slice_size: int,
):
saved_attention_dict = self.saved_cross_attention_maps[identifier]
if requested_dim is None:
if saved_attention_dict["dim"] is not None:
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
return saved_attention_dict["slices"][0]
if saved_attention_dict["dim"] == requested_dim:
if slice_size != saved_attention_dict["slice_size"]:
raise RuntimeError(
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
)
return saved_attention_dict["slices"][requested_offset]
if saved_attention_dict["dim"] is None:
whole_saved_attention = saved_attention_dict["slices"][0]
if requested_dim == 0:
return whole_saved_attention[requested_offset : requested_offset + slice_size]
elif requested_dim == 1:
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
return saved_attention["dim"], saved_attention["slice_size"]
def clear_requests(self, cleanup=True):
self.tokens_cross_attention_action = Context.Action.NONE
self.self_cross_attention_action = Context.Action.NONE
if cleanup:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for _key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict["slices"].items():
map_dict[offset] = slice.to("cpu")
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dynamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
self.attention_slice_calculated_callback = None
def set_attention_slice_wrangler(
self,
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
):
"""
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current Attention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
"""
self.attention_slice_wrangler = wrangler
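A hypothetical wrangler matching the signature documented above, purely as an illustration of the hook (the token index and scale factor are arbitrary):

def my_wrangler(module, suggested_attention_slice, dim, offset, slice_size):
    # boost attention to token index 5 slightly, otherwise keep the default slice
    boosted = suggested_attention_slice.clone()
    boosted[..., 5] = boosted[..., 5] * 1.1
    return boosted

# some_attention_module.set_attention_slice_wrangler(my_wrangler)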
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
# attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
attention_scores = torch.baddbmm(
torch.empty(
query.shape[0],
query.shape[1],
key.shape[1],
dtype=query.dtype,
device=query.device,
),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = get_mem_free_total(q.device)
# Divide by a safety factor, since there is copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == "cuda":
# print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == "mps" or q.device.type == "cpu":
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
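A worked example (illustrative shapes only) of how einsum_op_tensor_mem above picks a slicing strategy:

q_batch, q_tokens, k_tokens, elem_bytes = 16, 4096, 4096, 2   # fp16 attention scores
size_mb = q_batch * q_tokens * k_tokens * elem_bytes // (1 << 20)   # 512
div = 1 << int((size_mb - 1) / 32).bit_length()                     # 16 when max_tensor_mb = 32
assert div <= q_batch   # so dim-0 slicing is chosen, with slice_size = q_batch // div = 1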
def restore_default_cross_attention(
model,
is_running_diffusers: bool,
restore_attention_processor: Optional[AttentionProcessor] = None,
):
if is_running_diffusers:
unet = model
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
else:
remove_attention_function(model)
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
""" """
Inject attention parameters and functions into the passed in model to enable cross attention editing. Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -87,6 +362,170 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [
(name, module)
for name, module in model.named_modules()
if isinstance(module, cross_attention_class) and which_attn in name
]
cross_attention_modules_in_model_count = len(attention_module_tuples)
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
"failed or some assumption has changed about the structure of the model itself. Please fix the "
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
"attention map display will not work properly until it is fixed."
)
return attention_module_tuples
def inject_attention_function(unet, context: Context):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
attention_slice = suggested_attention_slice
if context.get_should_save_maps(module.identifier):
# print(module.identifier, "saving suggested_attention_slice of shape",
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
context.save_slice(
module.identifier,
slice_to_save,
dim=dim,
offset=offset,
slice_size=slice_size,
)
elif context.get_should_apply_saved_maps(module.identifier):
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
# slice may have been offloaded to CPU
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
if context.is_tokens_cross_attention(module.identifier):
index_map = context.cross_attention_index_map
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
this_attention_slice = suggested_attention_slice
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
saved_mask = mask
this_mask = 1 - mask
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
else:
# just use everything
attention_slice = saved_attention_slice
return attention_slice
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
else:
raise
def remove_attention_function(unet):
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for _identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
else:
raise
def is_attribute_error_about(error: AttributeError, attribute: str):
if hasattr(error, "name"): # Python 3.10
return error.name == attribute
else: # Python 3.9
return attribute in str(error)
def get_mem_free_total(device):
# only on cuda
if not torch.cuda.is_available():
return None
stats = torch.cuda.memory_stats(device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
InvokeAICrossAttentionMixin.__init__(self)
def _attention(self, query, key, value, attention_mask=None):
# default_result = super()._attention(query, key, value)
if attention_mask is not None:
print(f"{type(self).__name__} ignoring passed-in attention_mask")
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
return hidden_states
## 🧨diffusers implementation follows
"""
# base implementation
class AttnProcessor:
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
"""
@dataclass @dataclass
class SwapCrossAttnContext: class SwapCrossAttnContext:
modified_text_embeddings: torch.Tensor modified_text_embeddings: torch.Tensor
@ -94,6 +533,18 @@ class SwapCrossAttnContext:
mask: torch.Tensor # in the target space of the index_map mask: torch.Tensor # in the target space of the index_map
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list) cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
def __int__(
self,
cac_types_to_do: [CrossAttentionType],
modified_text_embeddings: torch.Tensor,
index_map: torch.Tensor,
mask: torch.Tensor,
):
self.cross_attention_types_to_do = cac_types_to_do
self.modified_text_embeddings = modified_text_embeddings
self.index_map = index_map
self.mask = mask
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool: def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
return attn_type in self.cross_attention_types_to_do return attn_type in self.cross_attention_types_to_do


@ -0,0 +1,100 @@
import math
from typing import Optional
import torch
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize
class AttentionMapSaver:
    def __init__(self, token_ids: range, latents_shape: torch.Size):
        self.token_ids = token_ids
        self.latents_shape = latents_shape
        # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
        self.collated_maps: dict[str, torch.Tensor] = {}

    def clear_maps(self):
        self.collated_maps = {}

    def add_attention_maps(self, maps: torch.Tensor, key: str):
        """
        Accumulate the given attention maps, summing them with any maps already stored under the passed-in key.
        :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
        :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the map sizes (H and W) should match.
        :return: None
        """
        key_and_size = f"{key}_{maps.shape[1]}"

        # extract desired tokens
        maps = maps[:, :, self.token_ids]

        # merge attention heads to a single map per token
        maps = torch.sum(maps, 0)

        # store
        if key_and_size not in self.collated_maps:
            self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
        self.collated_maps[key_and_size] += maps.cpu()

    def write_maps_to_disk(self, path: str):
        pil_image = self.get_stacked_maps_image()
        if pil_image is not None:
            pil_image.save(path, "PNG")

    def get_stacked_maps_image(self) -> Optional[Image.Image]:
        """
        Scale all collected attention maps to the same size, blend them together and return as an image.
        :return: An image containing a vertical stack of blended attention maps, one for each requested token.
        """
        num_tokens = len(self.token_ids)
        if num_tokens == 0:
            return None

        latents_height = self.latents_shape[0]
        latents_width = self.latents_shape[1]

        merged = None

        for _key, maps in self.collated_maps.items():
            # maps has shape [(H*W), N] for N tokens
            # but we want [N, H, W]
            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
            this_maps_height = int(float(latents_height) * this_scale_factor)
            this_maps_width = int(float(latents_width) * this_scale_factor)
            # and we need to do some dimension juggling
            maps = torch.reshape(
                torch.swapdims(maps, 0, 1),
                [num_tokens, this_maps_height, this_maps_width],
            )

            # scale to output size if necessary
            if this_scale_factor != 1:
                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)

            # normalize
            maps_min = torch.min(maps)
            maps_range = torch.max(maps) - maps_min
            # print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
            maps_normalized = (maps - maps_min) / maps_range
            # expand to (-0.1, 1.1) and clamp
            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)

            # merge together, producing a vertical stack
            maps_stacked = torch.reshape(
                maps_normalized_expanded_clamped,
                [num_tokens * latents_height, latents_width],
            )

            if merged is None:
                merged = maps_stacked
            else:
                # screen blend
                merged = 1 - (1 - maps_stacked) * (1 - merged)

        if merged is None:
            return None

        merged_bytes = merged.mul(0xFF).byte()
        return Image.fromarray(merged_bytes.numpy(), mode="L")
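
A short usage sketch for the saver above; the head count, map resolution, and storage key are illustrative assumptions:

    # 77 prompt tokens over 64x64 latents
    saver = AttentionMapSaver(token_ids=range(77), latents_shape=torch.Size([64, 64]))
    # e.g. 8 attention heads over a 32x32 token-attention map, stored under a made-up key
    saver.add_attention_maps(torch.rand(8, 32 * 32, 77), key="down_0")
    saver.write_maps_to_disk("attention_maps.png")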

View File

@ -17,11 +17,13 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
) )
from .cross_attention_control import ( from .cross_attention_control import (
Context,
CrossAttentionType, CrossAttentionType,
CrossAttnControlContext,
SwapCrossAttnContext, SwapCrossAttnContext,
get_cross_attention_modules,
setup_cross_attention_control_attention_processors, setup_cross_attention_control_attention_processors,
) )
from .cross_attention_map_saving import AttentionMapSaver
ModelForwardCallback: TypeAlias = Union[ ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs] # x, t, conditioning, Optional[cross-attention kwargs]
@ -67,12 +69,14 @@ class InvokeAIDiffuserComponent:
self, self,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo], extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
): ):
old_attn_processors = unet.attn_processors old_attn_processors = unet.attn_processors
try: try:
self.cross_attention_control_context = CrossAttnControlContext( self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args, arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
) )
setup_cross_attention_control_attention_processors( setup_cross_attention_control_attention_processors(
unet, unet,
@ -83,6 +87,27 @@ class InvokeAIDiffuserComponent:
finally: finally:
self.cross_attention_control_context = None self.cross_attention_control_context = None
unet.set_attn_processor(old_attn_processors) unet.set_attn_processor(old_attn_processors)
            # TODO resuscitate attention map saving
            # self.remove_attention_map_saving()

    def setup_attention_map_saving(self, saver: AttentionMapSaver):
        def callback(slice, dim, offset, slice_size, key):
            if dim is not None:
                # sliced tokens attention map saving is not implemented
                return
            saver.add_attention_maps(slice, key)

        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for identifier, module in tokens_cross_attention_modules:
            key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
            module.set_attention_slice_calculated_callback(
                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
            )

    def remove_attention_map_saving(self):
        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for _, module in tokens_cross_attention_modules:
            module.set_attention_slice_calculated_callback(None)
def do_controlnet_step( def do_controlnet_step(
self, self,
@ -199,47 +224,51 @@ class InvokeAIDiffuserComponent:
self, self,
sample: torch.Tensor, sample: torch.Tensor,
timestep: torch.Tensor, timestep: torch.Tensor,
conditioning_data: ConditioningData, conditioning_data, # TODO: type
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet **kwargs,
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
): ):
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count percent_through = step_index / total_step_count
cross_attention_control_types_to_do = ( cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through) percent_through
) )
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control or self.sequential_guidance: if wants_cross_attention_control:
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention (
# control is currently only supported in sequential mode. unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
sample,
timestep,
conditioning_data,
cross_attention_control_types_to_do,
**kwargs,
)
elif self.sequential_guidance:
( (
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning_sequentially( ) = self._apply_standard_conditioning_sequentially(
x=sample, sample,
sigma=timestep, timestep,
conditioning_data=conditioning_data, conditioning_data,
cross_attention_control_types_to_do=cross_attention_control_types_to_do, **kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
) )
else: else:
( (
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning( ) = self._apply_standard_conditioning(
x=sample, sample,
sigma=timestep, timestep,
conditioning_data=conditioning_data, conditioning_data,
down_block_additional_residuals=down_block_additional_residuals, **kwargs,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
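
For orientation: the (unconditioned, conditioned) pair returned here is typically combined downstream using the standard classifier-free guidance formula. A minimal sketch, with the helper name chosen for illustration only:

    def combine_with_guidance(unconditioned_next_x, conditioned_next_x, guidance_scale):
        # move from the unconditioned prediction towards the conditioned one, scaled by guidance_scale
        return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)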
@ -306,15 +335,7 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class. # methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning( def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
self,
x,
sigma,
conditioning_data: ConditioningData,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage. the cost of higher memory usage.
""" """
@ -362,10 +383,8 @@ class InvokeAIDiffuserComponent:
both_conditionings, both_conditionings,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
) )
unconditioned_next_x, conditioned_next_x = both_results.chunk(2) unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -375,17 +394,14 @@ class InvokeAIDiffuserComponent:
x: torch.Tensor, x: torch.Tensor,
sigma, sigma,
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
cross_attention_control_types_to_do: list[CrossAttentionType], **kwargs,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
): ):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed. slower execution speed.
""" """
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet # low-memory sequential path
# and T2I-Adapter residuals into two chunks.
uncond_down_block, cond_down_block = None, None uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None: if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], [] uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals: for down_block in down_block_additional_residuals:
@ -394,6 +410,7 @@ class InvokeAIDiffuserComponent:
cond_down_block.append(_cond_down) cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None: if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], [] uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals: for down_intrablock in down_intrablock_additional_residuals:
@ -402,29 +419,12 @@ class InvokeAIDiffuserComponent:
cond_down_intrablock.append(_cond_down) cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# If cross-attention control is enabled, prepare the SwapCrossAttnContext. # Run unconditional UNet denoising.
cross_attn_processor_context = None
if self.cross_attention_control_context is not None:
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
# will be populated before the conditioned pass.
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
index_map=self.cross_attention_control_context.cross_attention_index_map,
mask=self.cross_attention_control_context.cross_attention_mask,
cross_attention_types_to_do=[],
)
#####################
# Unconditioned pass
#####################
cross_attention_kwargs = None cross_attention_kwargs = None
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None: if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = { cross_attention_kwargs = {
@ -434,11 +434,6 @@ class InvokeAIDiffuserComponent:
] ]
} }
# Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl: if is_sdxl:
@ -447,7 +442,6 @@ class InvokeAIDiffuserComponent:
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
} }
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback( unconditioned_next_x = self.model_forward_callback(
x, x,
sigma, sigma,
@ -457,15 +451,11 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock, down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
) )
################### # Run conditional UNet denoising.
# Conditioned pass
###################
cross_attention_kwargs = None cross_attention_kwargs = None
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None: if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = { cross_attention_kwargs = {
@ -475,12 +465,6 @@ class InvokeAIDiffuserComponent:
] ]
} }
# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None added_cond_kwargs = None
if is_sdxl: if is_sdxl:
added_cond_kwargs = { added_cond_kwargs = {
@ -488,7 +472,6 @@ class InvokeAIDiffuserComponent:
"time_ids": conditioning_data.text_embeddings.add_time_ids, "time_ids": conditioning_data.text_embeddings.add_time_ids,
} }
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback( conditioned_next_x = self.model_forward_callback(
x, x,
sigma, sigma,
@ -498,6 +481,89 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock, down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
    def _apply_cross_attention_controlled_conditioning(
        self,
        x: torch.Tensor,
        sigma,
        conditioning_data,
        cross_attention_control_types_to_do,
        **kwargs,
    ):
        context: Context = self.cross_attention_control_context

        uncond_down_block, cond_down_block = None, None
        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
        if down_block_additional_residuals is not None:
            uncond_down_block, cond_down_block = [], []
            for down_block in down_block_additional_residuals:
                _uncond_down, _cond_down = down_block.chunk(2)
                uncond_down_block.append(_uncond_down)
                cond_down_block.append(_cond_down)

        uncond_down_intrablock, cond_down_intrablock = None, None
        down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
        if down_intrablock_additional_residuals is not None:
            uncond_down_intrablock, cond_down_intrablock = [], []
            for down_intrablock in down_intrablock_additional_residuals:
                _uncond_down, _cond_down = down_intrablock.chunk(2)
                uncond_down_intrablock.append(_uncond_down)
                cond_down_intrablock.append(_cond_down)

        uncond_mid_block, cond_mid_block = None, None
        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
        if mid_block_additional_residual is not None:
            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

        cross_attn_processor_context = SwapCrossAttnContext(
            modified_text_embeddings=context.arguments.edited_conditioning,
            index_map=context.cross_attention_index_map,
            mask=context.cross_attention_mask,
            cross_attention_types_to_do=[],
        )

        added_cond_kwargs = None
        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
            }

        # no cross attention for unconditioning (negative prompt)
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.unconditioned_embeddings.embeds,
            {"swap_cross_attn_context": cross_attn_processor_context},
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
            down_intrablock_additional_residuals=uncond_down_intrablock,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )

        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                "time_ids": conditioning_data.text_embeddings.add_time_ids,
            }

        # do requested cross attention types for conditioning (positive prompt)
        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.text_embeddings.embeds,
            {"swap_cross_attn_context": cross_attn_processor_context},
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
            down_intrablock_additional_residuals=cond_down_intrablock,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x
@ -567,3 +633,54 @@ class InvokeAIDiffuserComponent:
self.last_percent_through = percent_through self.last_percent_through = percent_through
return latents.to(device=dev) return latents.to(device=dev)
    # todo: make this work
    @classmethod
    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)  # aka sigmas

        deltas = None
        uncond_latents = None
        weighted_cond_list = (
            c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
        )

        # below is fugly omg
        conditionings = [uc] + [c for c, weight in weighted_cond_list]
        weights = [1] + [weight for c, weight in weighted_cond_list]
        chunk_count = math.ceil(len(conditionings) / 2)
        deltas = None
        for chunk_index in range(chunk_count):
            offset = chunk_index * 2
            chunk_size = min(2, len(conditionings) - offset)

            if chunk_size == 1:
                c_in = conditionings[offset]
                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                latents_b = None
            else:
                c_in = torch.cat(conditionings[offset : offset + 2])
                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
            if chunk_index == 0:
                uncond_latents = latents_a
                deltas = latents_b - uncond_latents
            else:
                deltas = torch.cat((deltas, latents_a - uncond_latents))
                if latents_b is not None:
                    deltas = torch.cat((deltas, latents_b - uncond_latents))

        # merge the weighted deltas together into a single merged delta
        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
        normalize = False
        if normalize:
            per_delta_weights /= torch.sum(per_delta_weights)
        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
        return uncond_latents + deltas_merged * global_guidance_scale
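
The final merge above is a weighted sum of per-conditioning deltas applied on top of the unconditioned latents. A toy numeric sketch of the same arithmetic (all numbers made up):

    # scalar stand-ins for latents: one unconditioned value, two conditionings with weights 1 and 0.5
    uncond, conds, weights, scale = 0.0, [1.0, 3.0], [1.0, 0.5], 7.5
    deltas_merged = sum(w * (c - uncond) for c, w in zip(conds, weights))  # 1*1.0 + 0.5*3.0 = 2.5
    result = uncond + deltas_merged * scale  # 18.75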

View File

@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
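
The hunk above restores the original embedding weights for every token except the newly added placeholder, so only that one row is actually trained. A minimal stand-alone sketch of the same masking idea; the vocabulary size and token id are made-up values:

    import torch

    vocab_size, embed_dim, placeholder_token_id = 49409, 768, 49408  # assumed values
    embeddings = torch.nn.Embedding(vocab_size, embed_dim)
    orig_embeds_params = embeddings.weight.data.clone()
    index_no_updates = torch.arange(vocab_size) != placeholder_token_id
    with torch.no_grad():
        # after an optimizer step, overwrite every row except the trainable placeholder row
        embeddings.weight[index_no_updates] = orig_embeds_params[index_no_updates]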

View File

@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.nextrely = top_of_table self.nextrely = top_of_table
self.lora_models = self.add_model_widgets( self.lora_models = self.add_model_widgets(
model_type=ModelType.LoRA, model_type=ModelType.Lora,
window_width=window_width, window_width=window_width,
) )
bottom_of_table = max(bottom_of_table, self.nextrely) bottom_of_table = max(bottom_of_table, self.nextrely)

View File

@ -30,7 +30,7 @@
"lint:prettier": "prettier --check .", "lint:prettier": "prettier --check .",
"lint:tsc": "tsc --noEmit", "lint:tsc": "tsc --noEmit",
"lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*", "lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
"fix": "eslint --fix . && prettier --log-level warn --write .", "fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
"preinstall": "npx only-allow pnpm", "preinstall": "npx only-allow pnpm",
"storybook": "storybook dev -p 6006", "storybook": "storybook dev -p 6006",
"build-storybook": "storybook build", "build-storybook": "storybook build",

View File

@ -134,6 +134,8 @@
"loadMore": "Mehr laden", "loadMore": "Mehr laden",
"noImagesInGallery": "Keine Bilder in der Galerie", "noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade", "loading": "Lade",
"preparingDownload": "bereite Download vor",
"preparingDownloadFailed": "Problem beim Download vorbereiten",
"deleteImage": "Lösche Bild", "deleteImage": "Lösche Bild",
"copy": "Kopieren", "copy": "Kopieren",
"download": "Runterladen", "download": "Runterladen",
@ -965,7 +967,7 @@
"resumeFailed": "Problem beim Fortsetzen des Prozesses", "resumeFailed": "Problem beim Fortsetzen des Prozesses",
"pruneFailed": "Problem beim leeren der Warteschlange", "pruneFailed": "Problem beim leeren der Warteschlange",
"pauseTooltip": "Prozess anhalten", "pauseTooltip": "Prozess anhalten",
"back": "Ende", "back": "Hinten",
"resumeSucceeded": "Prozess wird fortgesetzt", "resumeSucceeded": "Prozess wird fortgesetzt",
"resumeTooltip": "Prozess wieder aufnehmen", "resumeTooltip": "Prozess wieder aufnehmen",
"time": "Zeit", "time": "Zeit",

View File

@ -78,7 +78,6 @@
"aboutDesc": "Using Invoke for work? Check out:", "aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power", "aboutHeading": "Own Your Creative Power",
"accept": "Accept", "accept": "Accept",
"add": "Add",
"advanced": "Advanced", "advanced": "Advanced",
"advancedOptions": "Advanced Options", "advancedOptions": "Advanced Options",
"ai": "ai", "ai": "ai",
@ -304,12 +303,6 @@
"method": "High Resolution Fix Method" "method": "High Resolution Fix Method"
} }
}, },
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noPromptTriggers": "No triggers available",
"noMatchingTriggers": "No matching triggers"
},
"embedding": { "embedding": {
"addEmbedding": "Add Embedding", "addEmbedding": "Add Embedding",
"incompatibleModel": "Incompatible base model:", "incompatibleModel": "Incompatible base model:",
@ -741,8 +734,6 @@
"customConfig": "Custom Config", "customConfig": "Custom Config",
"customConfigFileLocation": "Custom Config File Location", "customConfigFileLocation": "Custom Config File Location",
"customSaveLocation": "Custom Save Location", "customSaveLocation": "Custom Save Location",
"defaultSettings": "Default Settings",
"defaultSettingsSaved": "Default Settings Saved",
"delete": "Delete", "delete": "Delete",
"deleteConfig": "Delete Config", "deleteConfig": "Delete Config",
"deleteModel": "Delete Model", "deleteModel": "Delete Model",
@ -777,7 +768,6 @@
"mergedModelName": "Merged Model Name", "mergedModelName": "Merged Model Name",
"mergedModelSaveLocation": "Save Location", "mergedModelSaveLocation": "Save Location",
"mergeModels": "Merge Models", "mergeModels": "Merge Models",
"metadata": "Metadata",
"model": "Model", "model": "Model",
"modelAdded": "Model Added", "modelAdded": "Model Added",
"modelConversionFailed": "Model Conversion Failed", "modelConversionFailed": "Model Conversion Failed",
@ -849,12 +839,9 @@
"statusConverting": "Converting", "statusConverting": "Converting",
"syncModels": "Sync Models", "syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.", "syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention", "upcastAttention": "Upcast Attention",
"updateModel": "Update Model", "updateModel": "Update Model",
"useCustomConfig": "Use Custom Config", "useCustomConfig": "Use Custom Config",
"useDefaultSettings": "Use Default Settings",
"v1": "v1", "v1": "v1",
"v2_768": "v2 (768px)", "v2_768": "v2 (768px)",
"v2_base": "v2 (512px)", "v2_base": "v2 (512px)",
@ -873,7 +860,6 @@
"models": { "models": {
"addLora": "Add LoRA", "addLora": "Add LoRA",
"allLoRAsAdded": "All LoRAs added", "allLoRAsAdded": "All LoRAs added",
"concepts": "Concepts",
"loraAlreadyAdded": "LoRA already added", "loraAlreadyAdded": "LoRA already added",
"esrganModel": "ESRGAN Model", "esrganModel": "ESRGAN Model",
"loading": "loading", "loading": "loading",

View File

@ -505,6 +505,8 @@
"seamLowThreshold": "Bajo", "seamLowThreshold": "Bajo",
"coherencePassHeader": "Parámetros de la coherencia", "coherencePassHeader": "Parámetros de la coherencia",
"compositingSettingsHeader": "Ajustes de la composición", "compositingSettingsHeader": "Ajustes de la composición",
"coherenceSteps": "Pasos",
"coherenceStrength": "Fuerza",
"patchmatchDownScaleSize": "Reducir a escala", "patchmatchDownScaleSize": "Reducir a escala",
"coherenceMode": "Modo" "coherenceMode": "Modo"
}, },

View File

@ -114,8 +114,7 @@
"checkpoint": "Checkpoint", "checkpoint": "Checkpoint",
"safetensors": "Safetensors", "safetensors": "Safetensors",
"ai": "ia", "ai": "ia",
"file": "File", "file": "File"
"toResolve": "Da risolvere"
}, },
"gallery": { "gallery": {
"generations": "Generazioni", "generations": "Generazioni",
@ -143,6 +142,8 @@
"copy": "Copia", "copy": "Copia",
"download": "Scarica", "download": "Scarica",
"setCurrentImage": "Imposta come immagine corrente", "setCurrentImage": "Imposta come immagine corrente",
"preparingDownload": "Preparazione del download",
"preparingDownloadFailed": "Problema durante la preparazione del download",
"downloadSelection": "Scarica gli elementi selezionati", "downloadSelection": "Scarica gli elementi selezionati",
"noImageSelected": "Nessuna immagine selezionata", "noImageSelected": "Nessuna immagine selezionata",
"deleteSelection": "Elimina la selezione", "deleteSelection": "Elimina la selezione",
@ -608,6 +609,8 @@
"seamLowThreshold": "Basso", "seamLowThreshold": "Basso",
"seamHighThreshold": "Alto", "seamHighThreshold": "Alto",
"coherencePassHeader": "Passaggio di coerenza", "coherencePassHeader": "Passaggio di coerenza",
"coherenceSteps": "Passi",
"coherenceStrength": "Forza",
"compositingSettingsHeader": "Impostazioni di composizione", "compositingSettingsHeader": "Impostazioni di composizione",
"patchmatchDownScaleSize": "Ridimensiona", "patchmatchDownScaleSize": "Ridimensiona",
"coherenceMode": "Modalità", "coherenceMode": "Modalità",
@ -1397,6 +1400,19 @@
"Regola la maschera." "Regola la maschera."
] ]
}, },
"compositingCoherenceSteps": {
"heading": "Passi",
"paragraphs": [
"Numero di passi utilizzati nel Passaggio di Coerenza.",
"Simile ai passi di generazione."
]
},
"compositingBlur": {
"heading": "Sfocatura",
"paragraphs": [
"Il raggio di sfocatura della maschera."
]
},
"compositingCoherenceMode": { "compositingCoherenceMode": {
"heading": "Modalità", "heading": "Modalità",
"paragraphs": [ "paragraphs": [
@ -1415,6 +1431,13 @@
"Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint." "Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
] ]
}, },
"compositingStrength": {
"heading": "Forza",
"paragraphs": [
"Quantità di rumore aggiunta per il Passaggio di Coerenza.",
"Simile alla forza di riduzione del rumore."
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.", "Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",

View File

@ -123,6 +123,8 @@
"autoSwitchNewImages": "새로운 이미지로 자동 전환", "autoSwitchNewImages": "새로운 이미지로 자동 전환",
"loading": "불러오는 중", "loading": "불러오는 중",
"unableToLoad": "갤러리를 로드할 수 없음", "unableToLoad": "갤러리를 로드할 수 없음",
"preparingDownload": "다운로드 준비",
"preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
"singleColumnLayout": "단일 열 레이아웃", "singleColumnLayout": "단일 열 레이아웃",
"image": "이미지", "image": "이미지",
"loadMore": "더 불러오기", "loadMore": "더 불러오기",

View File

@ -97,6 +97,8 @@
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.", "featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
"loading": "Bezig met laden", "loading": "Bezig met laden",
"unableToLoad": "Kan galerij niet laden", "unableToLoad": "Kan galerij niet laden",
"preparingDownload": "Bezig met voorbereiden van download",
"preparingDownloadFailed": "Fout bij voorbereiden van download",
"downloadSelection": "Download selectie", "downloadSelection": "Download selectie",
"currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:", "currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
"copy": "Kopieer", "copy": "Kopieer",
@ -533,6 +535,8 @@
"coherencePassHeader": "Coherentiestap", "coherencePassHeader": "Coherentiestap",
"maskBlur": "Vervaag", "maskBlur": "Vervaag",
"maskBlurMethod": "Vervagingsmethode", "maskBlurMethod": "Vervagingsmethode",
"coherenceSteps": "Stappen",
"coherenceStrength": "Sterkte",
"seamHighThreshold": "Hoog", "seamHighThreshold": "Hoog",
"seamLowThreshold": "Laag", "seamLowThreshold": "Laag",
"invoke": { "invoke": {
@ -1135,6 +1139,13 @@
"Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen." "Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
] ]
}, },
"compositingCoherenceSteps": {
"heading": "Stappen",
"paragraphs": [
"Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
"Gelijk aan de hoofdparameter Stappen."
]
},
"dynamicPrompts": { "dynamicPrompts": {
"paragraphs": [ "paragraphs": [
"Dynamische prompts vormt een enkele prompt om in vele.", "Dynamische prompts vormt een enkele prompt om in vele.",
@ -1149,6 +1160,12 @@
], ],
"heading": "VAE" "heading": "VAE"
}, },
"compositingBlur": {
"heading": "Vervaging",
"paragraphs": [
"De vervagingsstraal van het masker."
]
},
"paramIterations": { "paramIterations": {
"paragraphs": [ "paragraphs": [
"Het aantal te genereren afbeeldingen.", "Het aantal te genereren afbeeldingen.",
@ -1223,6 +1240,13 @@
], ],
"heading": "Ontruisingssterkte" "heading": "Ontruisingssterkte"
}, },
"compositingStrength": {
"heading": "Sterkte",
"paragraphs": [
"Ontruisingssterkte voor de coherentiefase.",
"Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.", "Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",

View File

@ -143,6 +143,8 @@
"problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений", "problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
"loading": "Загрузка", "loading": "Загрузка",
"unableToLoad": "Невозможно загрузить галерею", "unableToLoad": "Невозможно загрузить галерею",
"preparingDownload": "Подготовка к скачиванию",
"preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
"image": "изображение", "image": "изображение",
"drop": "перебросить", "drop": "перебросить",
"problemDeletingImages": "Проблема с удалением изображений", "problemDeletingImages": "Проблема с удалением изображений",
@ -610,7 +612,9 @@
"maskBlurMethod": "Метод размытия", "maskBlurMethod": "Метод размытия",
"seamLowThreshold": "Низкий", "seamLowThreshold": "Низкий",
"seamHighThreshold": "Высокий", "seamHighThreshold": "Высокий",
"coherenceSteps": "Шагов",
"coherencePassHeader": "Порог Coherence", "coherencePassHeader": "Порог Coherence",
"coherenceStrength": "Сила",
"compositingSettingsHeader": "Настройки компоновки", "compositingSettingsHeader": "Настройки компоновки",
"invoke": { "invoke": {
"noNodesInGraph": "Нет узлов в графе", "noNodesInGraph": "Нет узлов в графе",
@ -1317,6 +1321,13 @@
"Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL." "Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
] ]
}, },
"compositingCoherenceSteps": {
"heading": "Шаги",
"paragraphs": [
"Количество шагов снижения шума, используемых при прохождении когерентности.",
"То же, что и основной параметр «Шаги»."
]
},
"dynamicPrompts": { "dynamicPrompts": {
"paragraphs": [ "paragraphs": [
"Динамические запросы превращают одно приглашение на множество.", "Динамические запросы превращают одно приглашение на множество.",
@ -1331,6 +1342,12 @@
], ],
"heading": "VAE" "heading": "VAE"
}, },
"compositingBlur": {
"heading": "Размытие",
"paragraphs": [
"Радиус размытия маски."
]
},
"paramIterations": { "paramIterations": {
"paragraphs": [ "paragraphs": [
"Количество изображений, которые нужно сгенерировать.", "Количество изображений, которые нужно сгенерировать.",
@ -1405,6 +1422,13 @@
], ],
"heading": "Шумоподавление" "heading": "Шумоподавление"
}, },
"compositingStrength": {
"heading": "Сила",
"paragraphs": [
"Количество шума, добавляемого при прохождении когерентности.",
"То же, что параметр «Сила шумоподавления img2img»."
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.", "Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",

View File

@ -355,6 +355,7 @@
"starImage": "Yıldız Koy", "starImage": "Yıldız Koy",
"download": "İndir", "download": "İndir",
"deleteSelection": "Seçileni Sil", "deleteSelection": "Seçileni Sil",
"preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
"problemDeletingImages": "Görsel Silmede Sorun", "problemDeletingImages": "Görsel Silmede Sorun",
"featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.", "featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
"galleryImageResetSize": "Boyutu Resetle", "galleryImageResetSize": "Boyutu Resetle",
@ -376,6 +377,7 @@
"setCurrentImage": "Çalışma Görseli Yap", "setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi", "unableToLoad": "Galeri Yüklenemedi",
"downloadSelection": "Seçileni İndir", "downloadSelection": "Seçileni İndir",
"preparingDownload": "İndirmeye Hazırlanıyor",
"singleColumnLayout": "Tek Sütun Düzen", "singleColumnLayout": "Tek Sütun Düzen",
"generations": ıktılar", "generations": ıktılar",
"showUploads": "Yüklenenleri Göster", "showUploads": "Yüklenenleri Göster",
@ -721,6 +723,7 @@
"clipSkip": "CLIP Atlama", "clipSkip": "CLIP Atlama",
"randomizeSeed": "Rastgele Tohum", "randomizeSeed": "Rastgele Tohum",
"cfgScale": "CFG Ölçeği", "cfgScale": "CFG Ölçeği",
"coherenceStrength": "Etki",
"controlNetControlMode": "Yönetim Kipi", "controlNetControlMode": "Yönetim Kipi",
"general": "Genel", "general": "Genel",
"img2imgStrength": "Görselden Görsel Ölçüsü", "img2imgStrength": "Görselden Görsel Ölçüsü",
@ -790,6 +793,7 @@
"cfgRescaleMultiplier": "CFG Rescale Çarpanı", "cfgRescaleMultiplier": "CFG Rescale Çarpanı",
"cfgRescale": "CFG Rescale", "cfgRescale": "CFG Rescale",
"coherencePassHeader": "Uyum Geçişi", "coherencePassHeader": "Uyum Geçişi",
"coherenceSteps": "Adım",
"infillMethod": "Doldurma Yöntemi", "infillMethod": "Doldurma Yöntemi",
"maskBlurMethod": "Bulandırma Yöntemi", "maskBlurMethod": "Bulandırma Yöntemi",
"steps": "Adım", "steps": "Adım",

View File

@ -136,6 +136,8 @@
"copy": "复制", "copy": "复制",
"download": "下载", "download": "下载",
"setCurrentImage": "设为当前图像", "setCurrentImage": "设为当前图像",
"preparingDownload": "准备下载",
"preparingDownloadFailed": "准备下载时出现问题",
"downloadSelection": "下载所选内容", "downloadSelection": "下载所选内容",
"noImageSelected": "无选中的图像", "noImageSelected": "无选中的图像",
"deleteSelection": "删除所选内容", "deleteSelection": "删除所选内容",
@ -614,9 +616,11 @@
"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。" "incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
}, },
"patchmatchDownScaleSize": "缩小", "patchmatchDownScaleSize": "缩小",
"coherenceSteps": "步数",
"clipSkip": "CLIP 跳过层", "clipSkip": "CLIP 跳过层",
"compositingSettingsHeader": "合成设置", "compositingSettingsHeader": "合成设置",
"useCpuNoise": "使用 CPU 噪声", "useCpuNoise": "使用 CPU 噪声",
"coherenceStrength": "强度",
"enableNoiseSettings": "启用噪声设置", "enableNoiseSettings": "启用噪声设置",
"coherenceMode": "模式", "coherenceMode": "模式",
"cpuNoise": "CPU 噪声", "cpuNoise": "CPU 噪声",
@ -1398,6 +1402,19 @@
"图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸SDXL 模型使用等效 1024x1024 的尺寸。" "图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸SDXL 模型使用等效 1024x1024 的尺寸。"
] ]
}, },
"compositingCoherenceSteps": {
"heading": "步数",
"paragraphs": [
"一致性层中使用的去噪步数。",
"与主参数中的步数相同。"
]
},
"compositingBlur": {
"heading": "模糊",
"paragraphs": [
"遮罩模糊半径。"
]
},
"noiseUseCPU": { "noiseUseCPU": {
"heading": "使用 CPU 噪声", "heading": "使用 CPU 噪声",
"paragraphs": [ "paragraphs": [
@ -1450,6 +1467,13 @@
"第二轮去噪有助于合成内补/外扩图像。" "第二轮去噪有助于合成内补/外扩图像。"
] ]
}, },
"compositingStrength": {
"heading": "强度",
"paragraphs": [
"一致性层使用的去噪强度。",
"去噪强度与图生图的参数相同。"
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。", "生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",

View File

@ -55,8 +55,6 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store'; import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>; export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@ -153,7 +151,5 @@ addFirstListImagesListener(startAppListening);
// Ad-hoc upscale workflow // Ad-hoc upscale workflow
addUpscaleRequestedListener(startAppListening); addUpscaleRequestedListener(startAppListening);
// Prompts // Dynamic prompts
addDynamicPromptsListener(startAppListening); addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@ -7,10 +7,8 @@ import {
selectAllT2IAdapters, selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice'; import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models'; import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
@ -26,9 +24,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
const log = logger('models'); const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`); log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
const state = getState(); const currentModel = getState().generation.model;
const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload); const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) { if (models.length === 0) {
@ -43,29 +39,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
return; return;
} }
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]); const result = zParameterModel.safeParse(models[0]);
if (!result.success) { if (!result.success) {

View File

@ -1,96 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
setCfgScale,
setScheduler,
setSteps,
vaePrecisionChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterPrecision,
isParameterScheduler,
isParameterSteps,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const currentModel = state.generation.model;
if (!currentModel) {
return;
}
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
if (!modelConfig || !modelConfig.default_settings) {
return;
}
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
}
if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}
if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}
if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
},
});
};

View File

@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { bytes, total_bytes, id } = action.payload.data; const { bytes, total_bytes, id } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.bytes = bytes; modelImport.bytes = bytes;
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id } = action.payload.data; const { id } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.status = 'completed'; modelImport.status = 'completed';
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
dispatch(api.util.invalidateTags(['Model'])); dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
}, },
}); });
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id, error, error_type } = action.payload.data; const { id, error, error_type } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.status = 'error'; modelImport.status = 'error';

View File

@ -1,5 +1,4 @@
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap'; import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt'; import type { O } from 'ts-toolbelt';
@ -83,8 +82,6 @@ export type AppConfig = {
guidance: NumericalParameterConfig; guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig; cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig; img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
// Canvas // Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model boundingBoxWidth: NumericalParameterConfig; // initial value comes from model

View File

@ -8,15 +8,15 @@ type Props = {
onOpen: () => void; onOpen: () => void;
}; };
export const AddPromptTriggerButton = memo((props: Props) => { export const AddEmbeddingButton = memo((props: Props) => {
const { onOpen, isOpen } = props; const { onOpen, isOpen } = props;
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
<Tooltip label={t('prompt.addPromptTrigger')}> <Tooltip label={t('embedding.addEmbedding')}>
<IconButton <IconButton
variant="promptOverlay" variant="promptOverlay"
isDisabled={isOpen} isDisabled={isOpen}
aria-label={t('prompt.addPromptTrigger')} aria-label={t('embedding.addEmbedding')}
icon={<PiCodeBold />} icon={<PiCodeBold />}
onClick={onOpen} onClick={onOpen}
/> />
@ -24,4 +24,4 @@ export const AddPromptTriggerButton = memo((props: Props) => {
); );
}); });
AddPromptTriggerButton.displayName = 'AddPromptTriggerButton'; AddEmbeddingButton.displayName = 'AddEmbeddingButton';

View File

@ -1,9 +1,9 @@
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library'; import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect'; import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
import type { PromptPopoverProps } from 'features/prompt/types'; import type { EmbeddingPopoverProps } from 'features/embedding/types';
import { memo } from 'react'; import { memo } from 'react';
export const PromptPopover = memo((props: PromptPopoverProps) => { export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
const { onSelect, isOpen, onClose, width, children } = props; const { onSelect, isOpen, onClose, width, children } = props;
return ( return (
@ -14,7 +14,7 @@ export const PromptPopover = memo((props: PromptPopoverProps) => {
openDelay={0} openDelay={0}
closeDelay={0} closeDelay={0}
closeOnBlur={true} closeOnBlur={true}
returnFocusOnClose={false} returnFocusOnClose={true}
isLazy isLazy
> >
<PopoverAnchor>{children}</PopoverAnchor> <PopoverAnchor>{children}</PopoverAnchor>
@ -27,11 +27,11 @@ export const PromptPopover = memo((props: PromptPopoverProps) => {
borderStyle="solid" borderStyle="solid"
> >
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}> <PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} /> <EmbeddingSelect onClose={onClose} onSelect={onSelect} />
</PopoverBody> </PopoverBody>
</PopoverContent> </PopoverContent>
</Popover> </Popover>
); );
}); });
PromptPopover.displayName = 'PromptPopover'; EmbeddingPopover.displayName = 'EmbeddingPopover';

View File

@ -0,0 +1,21 @@
import type { Meta, StoryObj } from '@storybook/react';
import { EmbeddingSelect } from './EmbeddingSelect';
import type { EmbeddingSelectProps } from './types';
const meta: Meta<typeof EmbeddingSelect> = {
title: 'Feature/Prompt/EmbeddingSelect',
tags: ['autodocs'],
component: EmbeddingSelect,
};
export default meta;
type Story = StoryObj<typeof EmbeddingSelect>;
const Component = (props: EmbeddingSelectProps) => {
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
};
export const Default: Story = {
render: Component,
};

View File

@ -0,0 +1,67 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback(
(embedding: TextualInversionModelConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback(
(embedding: TextualInversionModelConfig | null) => {
if (!embedding) {
return;
}
onSelect(embedding.name);
},
[onSelect]
);
const { options, onChange } = useGroupedModelCombobox({
modelEntities: data,
getIsDisabled,
onChange: _onChange,
});
return (
<FormControl>
<Combobox
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
noOptionsMessage={noOptionsMessage}
onChange={onChange}
onMenuClose={onClose}
data-testid="add-embedding"
sx={selectStyles}
/>
</FormControl>
);
});
EmbeddingSelect.displayName = 'EmbeddingSelect';
const selectStyles: ChakraProps['sx'] = {
w: 'full',
};

View File

@ -1,12 +1,12 @@
import type { PropsWithChildren } from 'react'; import type { PropsWithChildren } from 'react';
export type PromptTriggerSelectProps = { export type EmbeddingSelectProps = {
onSelect: (v: string) => void; onSelect: (v: string) => void;
onClose: () => void; onClose: () => void;
}; };
export type PromptPopoverProps = PropsWithChildren & export type EmbeddingPopoverProps = PropsWithChildren &
PromptTriggerSelectProps & { EmbeddingSelectProps & {
isOpen: boolean; isOpen: boolean;
width?: number | string; width?: number | string;
}; };

View File

@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
import { useCallback } from 'react'; import { useCallback } from 'react';
import { flushSync } from 'react-dom'; import { flushSync } from 'react-dom';
type UseInsertTriggerArg = { type UseInsertEmbeddingArg = {
prompt: string; prompt: string;
textareaRef: RefObject<HTMLTextAreaElement>; textareaRef: RefObject<HTMLTextAreaElement>;
onChange: (v: string) => void; onChange: (v: string) => void;
}; };
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => { export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
const { isOpen, onClose, onOpen } = useDisclosure(); const { isOpen, onClose, onOpen } = useDisclosure();
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback( const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
[_onChange] [_onChange]
); );
const insertTrigger = useCallback( const insertEmbedding = useCallback(
(v: string) => { (v: string) => {
if (!textareaRef.current) { if (!textareaRef.current) {
return; return;
} }
// this is where we insert the trigger // this is where we insert the TI trigger
const caret = textareaRef.current.selectionStart; const caret = textareaRef.current.selectionStart;
if (isNil(caret)) { if (isNil(caret)) {
@ -35,9 +35,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
let newPrompt = prompt.slice(0, caret); let newPrompt = prompt.slice(0, caret);
newPrompt += `${v}`; if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
// we insert the cursor after the end of trigger newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length; const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret); newPrompt += prompt.slice(caret);
@ -47,7 +51,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
_onChange(newPrompt); _onChange(newPrompt);
}); });
// set the cursor position to just after the trigger // set the caret position to just after the TI trigger
textareaRef.current.selectionStart = finalCaretPos; textareaRef.current.selectionStart = finalCaretPos;
textareaRef.current.selectionEnd = finalCaretPos; textareaRef.current.selectionEnd = finalCaretPos;
}, },
@ -58,17 +62,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
textareaRef.current?.focus(); textareaRef.current?.focus();
}, [textareaRef]); }, [textareaRef]);
const handleClosePopover = useCallback(() => { const handleClose = useCallback(() => {
onClose(); onClose();
onFocus(); onFocus();
}, [onFocus, onClose]); }, [onFocus, onClose]);
const onSelect = useCallback( const onSelectEmbedding = useCallback(
(v: string) => { (v: string) => {
insertTrigger(v); insertEmbedding(v);
handleClosePopover(); handleClose();
}, },
[handleClosePopover, insertTrigger] [handleClose, insertEmbedding]
); );
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback( const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
@ -86,7 +90,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
isOpen, isOpen,
onClose, onClose,
onOpen, onOpen,
onSelect, onSelectEmbedding,
onKeyDown, onKeyDown,
onFocus, onFocus,
}; };
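
An illustrative-only model of the caret handling performed by the reworked hook: the selected embedding name is wrapped in `<...>` at the caret, an already-typed `<` is reused rather than doubled, and the caret is reported as landing just after the closing `>`. The standalone helper name and return shape are assumptions for the sketch.

```ts
const insertTriggerAt = (prompt: string, caret: number, name: string) => {
  let next = prompt.slice(0, caret);
  // reuse a '<' the user already typed rather than doubling it
  if (next[next.length - 1] !== '<') {
    next += '<';
  }
  next += `${name}>`;
  const finalCaretPos = next.length; // caret lands just after the '>'
  next += prompt.slice(caret);
  return { next, finalCaretPos };
};

// insertTriggerAt('a photo of ', 11, 'easynegative')
//   -> { next: 'a photo of <easynegative>', finalCaretPos: 25 }
```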

View File

@ -59,7 +59,7 @@ const LoRASelect = () => {
return ( return (
<FormControl isDisabled={!options.length}> <FormControl isDisabled={!options.length}>
<InformationalPopover feature="lora"> <InformationalPopover feature="lora">
<FormLabel>{t('models.concepts')} </FormLabel> <FormLabel>{t('models.lora')} </FormLabel>
</InformationalPopover> </InformationalPopover>
<Combobox <Combobox
placeholder={placeholder} placeholder={placeholder}

View File

@ -0,0 +1,228 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => {
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
setValue,
resetField,
reset,
watch,
} = useForm<AnyModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t, reset]
);
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);
return (
<ScrollableContent>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};
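
A compact restatement of the format defaults the effect above applies when the model type changes (the helper name is hypothetical). Selecting 'main' additionally resets repo_variant and variant, and every type change clears the optional fields via resetField.

```ts
const defaultFormatFor = (type: string): string => {
  switch (type) {
    case 'lora':
      return 'lycoris';
    case 'embedding':
      return 'embedding_file';
    case 'ip_adapter':
      return 'invokeai';
    default:
      return 'diffusers'; // 'main' and any other type fall back to diffusers
  }
};
```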

View File

@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models'; import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { ModelInstallQueueItem } from './ModelInstallQueueItem'; import { ImportQueueItem } from './ImportQueueItem';
export const ModelInstallQueue = () => { export const ImportQueue = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useListModelInstallsQuery(); const { data } = useGetModelImportsQuery();
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation(); const [pruneModelImports] = usePruneModelImportsMutation();
const pruneCompletedModelInstalls = useCallback(() => { const pruneQueue = useCallback(() => {
_pruneCompletedModelInstalls() pruneModelImports()
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -41,7 +41,7 @@ export const ModelInstallQueue = () => {
); );
} }
}); });
}, [_pruneCompletedModelInstalls, dispatch]); }, [pruneModelImports, dispatch]);
const pruneAvailable = useMemo(() => { const pruneAvailable = useMemo(() => {
return data?.some( return data?.some(
@ -53,19 +53,14 @@ export const ModelInstallQueue = () => {
<Flex flexDir="column" p={3} h="full"> <Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text> <Text>{t('modelManager.importQueue')}</Text>
<Button <Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
size="sm"
isDisabled={!pruneAvailable}
onClick={pruneCompletedModelInstalls}
tooltip={t('modelManager.pruneTooltip')}
>
{t('modelManager.prune')} {t('modelManager.prune')}
</Button> </Button>
</Flex> </Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full"> <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent> <ScrollableContent>
<Flex flexDir="column-reverse" gap="2"> <Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)} {data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>
</Box> </Box>

View File

@ -6,24 +6,17 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = { const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' }, waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' }, downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' }, running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' }, completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' }, error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' }, cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
}; };
const ModelInstallQueueBadge = ({ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const { t } = useTranslation(); const { t } = useTranslation();
if (!status || !Object.keys(STATUSES).includes(status)) { if (!status) {
return null; return <></>;
} }
return ( return (
@ -32,4 +25,4 @@ const ModelInstallQueueBadge = ({
</Tooltip> </Tooltip>
); );
}; };
export default memo(ModelInstallQueueBadge); export default memo(ImportQueueBadge);

View File

@ -3,16 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models'; import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types'; import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ModelInstallQueueBadge from './ModelInstallQueueBadge'; import ImportQueueBadge from './ImportQueueBadge';
type ModelListItemProps = { type ModelListItemProps = {
installJob: ModelInstallJob; model: ModelInstallJob;
}; };
const formatBytes = (bytes: number) => { const formatBytes = (bytes: number) => {
@ -27,26 +26,26 @@ const formatBytes = (bytes: number) => {
return `${bytes.toFixed(2)} ${units[i]}`; return `${bytes.toFixed(2)} ${units[i]}`;
}; };
export const ModelInstallQueueItem = (props: ModelListItemProps) => { export const ImportQueueItem = (props: ModelListItemProps) => {
const { installJob } = props; const { model } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [deleteImportModel] = useCancelModelInstallMutation(); const [deleteImportModel] = useDeleteModelImportMutation();
const source = useMemo(() => { const source = useMemo(() => {
if (installJob.source.type === 'hf') { if (model.source.type === 'hf') {
return installJob.source as HFModelSource; return model.source as HFModelSource;
} else if (installJob.source.type === 'local') { } else if (model.source.type === 'local') {
return installJob.source as LocalModelSource; return model.source as LocalModelSource;
} else if (installJob.source.type === 'url') { } else if (model.source.type === 'url') {
return installJob.source as URLModelSource; return model.source as URLModelSource;
} else { } else {
return installJob.source as LocalModelSource; return model.source as LocalModelSource;
} }
}, [installJob.source]); }, [model.source]);
const handleDeleteModelImport = useCallback(() => { const handleDeleteModelImport = useCallback(() => {
deleteImportModel(installJob.id) deleteImportModel(model.id)
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -70,7 +69,7 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
); );
} }
}); });
}, [deleteImportModel, installJob, dispatch]); }, [deleteImportModel, model, dispatch]);
const modelName = useMemo(() => { const modelName = useMemo(() => {
switch (source.type) { switch (source.type) {
@ -86,23 +85,19 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
}, [source]); }, [source]);
const progressValue = useMemo(() => { const progressValue = useMemo(() => {
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) { if (model.bytes === undefined || model.total_bytes === undefined) {
return null;
}
if (installJob.total_bytes === 0) {
return 0; return 0;
} }
return (installJob.bytes / installJob.total_bytes) * 100; return (model.bytes / model.total_bytes) * 100;
}, [installJob.bytes, installJob.total_bytes]); }, [model.bytes, model.total_bytes]);
const progressString = useMemo(() => { const progressString = useMemo(() => {
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) { if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
return ''; return '';
} }
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`; return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [installJob.bytes, installJob.total_bytes, installJob.status]); }, [model.bytes, model.total_bytes, model.status]);
return ( return (
<Flex gap="2" w="full" alignItems="center"> <Flex gap="2" w="full" alignItems="center">
@ -114,21 +109,19 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
<Flex flexDir="column" flex={1}> <Flex flexDir="column" flex={1}>
<Tooltip label={progressString}> <Tooltip label={progressString}>
<Progress <Progress
value={progressValue ?? 0} value={progressValue}
isIndeterminate={progressValue === null} isIndeterminate={progressValue === undefined}
aria-label={t('accessibility.invokeProgressBar')} aria-label={t('accessibility.invokeProgressBar')}
h={2} h={2}
/> />
</Tooltip> </Tooltip>
</Flex> </Flex>
<Box minW="100px" textAlign="center"> <Box minW="100px" textAlign="center">
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} /> <ImportQueueBadge status={model.status} errorReason={model.error_reason} />
</Box> </Box>
<Box minW="20px"> <Box minW="20px">
{(installJob.status === 'downloading' || {(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton <IconButton
isRound={true} isRound={true}
size="xs" size="xs"

View File

@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useLazyScanFolderQuery } from 'services/api/endpoints/models'; import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
import { ScanModelsResults } from './ScanFolderResults'; import { ScanModelsResults } from './ScanModelsResults';
export const ScanModelsForm = () => { export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState(''); const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState(''); const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation(); const { t } = useTranslation();
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery(); const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const scanFolder = useCallback(async () => { const handleSubmitScan = useCallback(async () => {
_scanFolder({ scan_path: scanPath }).catch((error) => { _scanModels({ scan_path: scanPath }).catch((error) => {
if (error) { if (error) {
setErrorMessage(error.data.detail); setErrorMessage(error.data.detail);
} }
}); });
}, [_scanFolder, scanPath]); }, [_scanModels, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => { const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setScanPath(e.target.value); setScanPath(e.target.value);
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
<Input value={scanPath} onChange={handleSetScanPath} /> <Input value={scanPath} onChange={handleSetScanPath} />
</Flex> </Flex>
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}> <Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
{t('modelManager.scanFolder')} {t('modelManager.scanFolder')}
</Button> </Button>
</Flex> </Flex>

View File

@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models'; import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanFolderResultItem'; import { ScanModelResultItem } from './ScanModelResultItem';
type ScanModelResultsProps = { type ScanModelResultsProps = {
results: ScanFolderResponse; results: ScanFolderResponse;

View File

@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string; location: string;
}; };
export const InstallModelForm = () => { export const SimpleImport = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModelMutation(); const [installModel, { isLoading }] = useInstallModelMutation();

Some files were not shown because too many files have changed in this diff.