Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: separate-g… ... Invoke-Upd… — 1 commit (c953051eae)
.github/actions/install-frontend-deps/action.yml (14 changes)

@@ -1,33 +1,33 @@
-name: install frontend dependencies
+name: Install frontend dependencies
 description: Installs frontend dependencies with pnpm, with caching
 runs:
   using: 'composite'
   steps:
-    - name: setup node 18
+    - name: Setup Node 18
       uses: actions/setup-node@v4
       with:
         node-version: '18'

-    - name: setup pnpm
+    - name: Setup pnpm
       uses: pnpm/action-setup@v2
       with:
         version: 8
         run_install: false

-    - name: get pnpm store directory
+    - name: Get pnpm store directory
       shell: bash
       run: |
         echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV

-    - name: setup cache
-      uses: actions/cache@v4
+    - uses: actions/cache@v3
+      name: Setup pnpm cache
       with:
         path: ${{ env.STORE_PATH }}
         key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
         restore-keys: |
           ${{ runner.os }}-pnpm-store-

-    - name: install frontend dependencies
+    - name: Install frontend dependencies
       run: pnpm install --prefer-frozen-lockfile
       shell: bash
       working-directory: invokeai/frontend/web
.github/actions/install-python-deps/action.yml (new file, 11 lines)

@@ -0,0 +1,11 @@
+name: Install python dependencies
+description: Install python dependencies with pip, with caching
+runs:
+  using: 'composite'
+  steps:
+    - name: Setup python
+      uses: actions/setup-python@v5
+      with:
+        python-version: '3.10'
+        cache: pip
+        cache-dependency-path: pyproject.toml
.github/pr_labels.yml (28 changes)

@@ -1,59 +1,59 @@
-root:
+Root:
   - changed-files:
       - any-glob-to-any-file: '*'

-python-deps:
+PythonDeps:
   - changed-files:
       - any-glob-to-any-file: 'pyproject.toml'

-python:
+Python:
   - changed-files:
       - all-globs-to-any-file:
           - 'invokeai/**'
           - '!invokeai/frontend/web/**'

-python-tests:
+PythonTests:
   - changed-files:
       - any-glob-to-any-file: 'tests/**'

-ci-cd:
+CICD:
   - changed-files:
       - any-glob-to-any-file: .github/**

-docker:
+Docker:
   - changed-files:
       - any-glob-to-any-file: docker/**

-installer:
+Installer:
   - changed-files:
       - any-glob-to-any-file: installer/**

-docs:
+Documentation:
   - changed-files:
       - any-glob-to-any-file: docs/**

-invocations:
+Invocations:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/app/invocations/**'

-backend:
+Backend:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/backend/**'

-api:
+Api:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/app/api/**'

-services:
+Services:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/app/services/**'

-frontend-deps:
+FrontendDeps:
   - changed-files:
       - any-glob-to-any-file:
           - '**/*/package.json'
           - '**/*/pnpm-lock.yaml'

-frontend:
+Frontend:
   - changed-files:
       - any-glob-to-any-file: 'invokeai/frontend/web/**'
.github/workflows/build-installer.yml (deleted, 45 lines)

@@ -1,45 +0,0 @@
-# Builds and uploads the installer and python build artifacts.
-
-name: build installer
-
-on:
-  workflow_dispatch:
-  workflow_call:
-
-jobs:
-  build-installer:
-    runs-on: ubuntu-latest
-    timeout-minutes: 5 # expected run time: <2 min
-    steps:
-      - name: checkout
-        uses: actions/checkout@v4
-
-      - name: setup python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.10'
-          cache: pip
-          cache-dependency-path: pyproject.toml
-
-      - name: install pypa/build
-        run: pip install --upgrade build
-
-      - name: setup frontend
-        uses: ./.github/actions/install-frontend-deps
-
-      - name: create installer
-        id: create_installer
-        run: ./create_installer.sh
-        working-directory: installer
-
-      - name: upload python distribution artifact
-        uses: actions/upload-artifact@v4
-        with:
-          name: dist
-          path: ${{ steps.create_installer.outputs.DIST_PATH }}
-
-      - name: upload installer artifact
-        uses: actions/upload-artifact@v4
-        with:
-          name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
-          path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
.github/workflows/check-frontend.yml (new file, 43 lines)

@@ -0,0 +1,43 @@
+# This workflow runs the frontend code quality checks.
+#
+# It may be triggered via dispatch, or by another workflow.
+
+name: 'Check: frontend'
+
+on:
+  workflow_dispatch:
+  workflow_call:
+
+defaults:
+  run:
+    working-directory: invokeai/frontend/web
+
+jobs:
+  check-frontend:
+    runs-on: ubuntu-latest
+    timeout-minutes: 10 # expected run time: <2 min
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up frontend
+        uses: ./.github/actions/install-frontend-deps
+
+      - name: Run tsc check
+        run: 'pnpm run lint:tsc'
+        shell: bash
+
+      - name: Run dpdm check
+        run: 'pnpm run lint:dpdm'
+        shell: bash
+
+      - name: Run eslint check
+        run: 'pnpm run lint:eslint'
+        shell: bash
+
+      - name: Run prettier check
+        run: 'pnpm run lint:prettier'
+        shell: bash
+
+      - name: Run knip check
+        run: 'pnpm run lint:knip'
+        shell: bash
.github/workflows/check-pytest.yml (new file, 72 lines)

@@ -0,0 +1,72 @@
+# This workflow runs pytest on the codebase in a matrix of platforms.
+#
+# It may be triggered via dispatch, or by another workflow.
+
+name: 'Check: pytest'
+
+on:
+  workflow_dispatch:
+  workflow_call:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  matrix:
+    strategy:
+      matrix:
+        python-version:
+          - '3.10'
+        pytorch:
+          - linux-cuda-11_7
+          - linux-rocm-5_2
+          - linux-cpu
+          - macos-default
+          - windows-cpu
+        include:
+          - pytorch: linux-cuda-11_7
+            os: ubuntu-22.04
+            github-env: $GITHUB_ENV
+          - pytorch: linux-rocm-5_2
+            os: ubuntu-22.04
+            extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
+            github-env: $GITHUB_ENV
+          - pytorch: linux-cpu
+            os: ubuntu-22.04
+            extra-index-url: 'https://download.pytorch.org/whl/cpu'
+            github-env: $GITHUB_ENV
+          - pytorch: macos-default
+            os: macOS-12
+            github-env: $GITHUB_ENV
+          - pytorch: windows-cpu
+            os: windows-2022
+            github-env: $env:GITHUB_ENV
+    name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
+    runs-on: ${{ matrix.os }}
+    timeout-minutes: 30 # expected run time: <10 min, depending on platform
+    env:
+      PIP_USE_PEP517: '1'
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: set test prompt to main branch validation
+        run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
+
+      - name: setup python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: pip
+          cache-dependency-path: pyproject.toml
+
+      - name: install invokeai
+        env:
+          PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
+        run: >
+          pip3 install
+          --editable=".[test]"
+
+      - name: run pytest
+        id: run-pytest
+        run: pytest
.github/workflows/check-python.yml (new file, 33 lines)

@@ -0,0 +1,33 @@
+# This workflow runs the python code quality checks.
+#
+# It may be triggered via dispatch, or by another workflow.
+#
+# TODO: Add mypy or pyright to the checks.
+
+name: 'Check: python'
+
+on:
+  workflow_dispatch:
+  workflow_call:
+
+jobs:
+  check-backend:
+    runs-on: ubuntu-latest
+    timeout-minutes: 5 # expected run time: <1 min
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install python dependencies
+        uses: ./.github/actions/install-python-deps
+
+      - name: Install ruff
+        run: pip install ruff
+        shell: bash
+
+      - name: Ruff check
+        run: ruff check --output-format=github .
+        shell: bash
+
+      - name: Ruff format
+        run: ruff format --check .
+        shell: bash
.github/workflows/frontend-checks.yml (deleted, 68 lines)

@@ -1,68 +0,0 @@
-# Runs frontend code quality checks.
-#
-# Checks for changes to frontend files before running the checks.
-# When manually triggered or when called from another workflow, always runs the checks.
-
-name: 'frontend checks'
-
-on:
-  push:
-    branches:
-      - 'main'
-  pull_request:
-    types:
-      - 'ready_for_review'
-      - 'opened'
-      - 'synchronize'
-  merge_group:
-  workflow_dispatch:
-  workflow_call:
-
-defaults:
-  run:
-    working-directory: invokeai/frontend/web
-
-jobs:
-  frontend-checks:
-    runs-on: ubuntu-latest
-    timeout-minutes: 10 # expected run time: <2 min
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: check for changed frontend files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
-        id: changed-files
-        uses: tj-actions/changed-files@v42
-        with:
-          files_yaml: |
-            frontend:
-              - 'invokeai/frontend/web/**'
-
-      - name: install dependencies
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        uses: ./.github/actions/install-frontend-deps
-
-      - name: tsc
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: 'pnpm lint:tsc'
-        shell: bash
-
-      - name: dpdm
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: 'pnpm lint:dpdm'
-        shell: bash
-
-      - name: eslint
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: 'pnpm lint:eslint'
-        shell: bash
-
-      - name: prettier
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: 'pnpm lint:prettier'
-        shell: bash
-
-      - name: knip
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: 'pnpm lint:knip'
-        shell: bash
.github/workflows/frontend-tests.yml (deleted, 48 lines)

@@ -1,48 +0,0 @@
-# Runs frontend tests.
-#
-# Checks for changes to frontend files before running the tests.
-# When manually triggered or called from another workflow, always runs the tests.
-
-name: 'frontend tests'
-
-on:
-  push:
-    branches:
-      - 'main'
-  pull_request:
-    types:
-      - 'ready_for_review'
-      - 'opened'
-      - 'synchronize'
-  merge_group:
-  workflow_dispatch:
-  workflow_call:
-
-defaults:
-  run:
-    working-directory: invokeai/frontend/web
-
-jobs:
-  frontend-tests:
-    runs-on: ubuntu-latest
-    timeout-minutes: 10 # expected run time: <2 min
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: check for changed frontend files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
-        id: changed-files
-        uses: tj-actions/changed-files@v42
-        with:
-          files_yaml: |
-            frontend:
-              - 'invokeai/frontend/web/**'
-
-      - name: install dependencies
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        uses: ./.github/actions/install-frontend-deps
-
-      - name: vitest
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: 'pnpm test:no-watch'
-        shell: bash
.github/workflows/label-pr.yml (12 changes)

@@ -1,6 +1,6 @@
-name: 'label PRs'
+name: "Pull Request Labeler"
 on:
   - pull_request_target

 jobs:
   labeler:
@@ -9,10 +9,8 @@ jobs:
     pull-requests: write
     runs-on: ubuntu-latest
     steps:
-      - name: checkout
+      - name: Checkout
         uses: actions/checkout@v4

-      - name: label PRs
-        uses: actions/labeler@v5
+      - uses: actions/labeler@v5
         with:
           configuration-path: .github/pr_labels.yml
.github/workflows/mkdocs-material.yml (23 changes)

@@ -21,29 +21,18 @@ jobs:
       SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'

     steps:
-      - name: checkout
-        uses: actions/checkout@v4
-
-      - name: setup python
-        uses: actions/setup-python@v5
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
           python-version: '3.10'
           cache: pip
           cache-dependency-path: pyproject.toml
-
-      - name: set cache id
-        run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
-
-      - name: use cache
-        uses: actions/cache@v4
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - uses: actions/cache@v4
         with:
           key: mkdocs-material-${{ env.cache_id }}
           path: .cache
           restore-keys: |
             mkdocs-material-
-
-      - name: install dependencies
-        run: python -m pip install ".[docs]"
-
-      - name: build & deploy
-        run: mkdocs gh-deploy --force
+      - run: python -m pip install ".[docs]"
+      - run: mkdocs gh-deploy --force
.github/workflows/on-change-check-frontend.yml (new file, 39 lines)

@@ -0,0 +1,39 @@
+# This workflow runs of `check-frontend.yml` on push or pull request.
+#
+# The actual checks are in a separate workflow to support simpler workflow
+# composition without awkward or complicated conditionals.
+
+name: 'On change: run check-frontend'
+
+on:
+  push:
+    branches:
+      - 'main'
+  pull_request:
+    types:
+      - 'ready_for_review'
+      - 'opened'
+      - 'synchronize'
+  merge_group:
+
+jobs:
+  check-changed-frontend-files:
+    if: github.event.pull_request.draft == false
+    runs-on: ubuntu-latest
+    outputs:
+      frontend_any_changed: ${{ steps.changed-files.outputs.frontend_any_changed }}
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Check for changed frontend files
+        id: changed-files
+        uses: tj-actions/changed-files@v41
+        with:
+          files_yaml: |
+            frontend:
+              - 'invokeai/frontend/web/**'
+
+  run-check-frontend:
+    needs: check-changed-frontend-files
+    if: ${{ needs.check-changed-frontend-files.outputs.frontend_any_changed == 'true' }}
+    uses: ./.github/workflows/check-frontend.yml
.github/workflows/on-change-check-python.yml (new file, 42 lines)

@@ -0,0 +1,42 @@
+# This workflow runs of `check-python.yml` on push or pull request.
+#
+# The actual checks are in a separate workflow to support simpler workflow
+# composition without awkward or complicated conditionals.
+
+name: 'On change: run check-python'
+
+on:
+  push:
+    branches:
+      - 'main'
+  pull_request:
+    types:
+      - 'ready_for_review'
+      - 'opened'
+      - 'synchronize'
+  merge_group:
+
+jobs:
+  check-changed-python-files:
+    if: github.event.pull_request.draft == false
+    runs-on: ubuntu-latest
+    outputs:
+      python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }}
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Check for changed python files
+        id: changed-files
+        uses: tj-actions/changed-files@v41
+        with:
+          files_yaml: |
+            python:
+              - 'pyproject.toml'
+              - 'invokeai/**'
+              - '!invokeai/frontend/web/**'
+              - 'tests/**'
+
+  run-check-python:
+    needs: check-changed-python-files
+    if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }}
+    uses: ./.github/workflows/check-python.yml
.github/workflows/on-change-pytest.yml (new file, 42 lines)

@@ -0,0 +1,42 @@
+# This workflow runs of `check-pytest.yml` on push or pull request.
+#
+# The actual checks are in a separate workflow to support simpler workflow
+# composition without awkward or complicated conditionals.
+
+name: 'On change: run pytest'
+
+on:
+  push:
+    branches:
+      - 'main'
+  pull_request:
+    types:
+      - 'ready_for_review'
+      - 'opened'
+      - 'synchronize'
+  merge_group:
+
+jobs:
+  check-changed-python-files:
+    if: github.event.pull_request.draft == false
+    runs-on: ubuntu-latest
+    outputs:
+      python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }}
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Check for changed python files
+        id: changed-files
+        uses: tj-actions/changed-files@v41
+        with:
+          files_yaml: |
+            python:
+              - 'pyproject.toml'
+              - 'invokeai/**'
+              - '!invokeai/frontend/web/**'
+              - 'tests/**'
+
+  run-pytest:
+    needs: check-changed-python-files
+    if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }}
+    uses: ./.github/workflows/check-pytest.yml
.github/workflows/python-checks.yml (deleted, 64 lines)

@@ -1,64 +0,0 @@
-# Runs python code quality checks.
-#
-# Checks for changes to python files before running the checks.
-# When manually triggered or called from another workflow, always runs the tests.
-#
-# TODO: Add mypy or pyright to the checks.
-
-name: 'python checks'
-
-on:
-  push:
-    branches:
-      - 'main'
-  pull_request:
-    types:
-      - 'ready_for_review'
-      - 'opened'
-      - 'synchronize'
-  merge_group:
-  workflow_dispatch:
-  workflow_call:
-
-jobs:
-  python-checks:
-    runs-on: ubuntu-latest
-    timeout-minutes: 5 # expected run time: <1 min
-    steps:
-      - name: checkout
-        uses: actions/checkout@v4
-
-      - name: check for changed python files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
-        id: changed-files
-        uses: tj-actions/changed-files@v42
-        with:
-          files_yaml: |
-            python:
-              - 'pyproject.toml'
-              - 'invokeai/**'
-              - '!invokeai/frontend/web/**'
-              - 'tests/**'
-
-      - name: setup python
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.10'
-          cache: pip
-          cache-dependency-path: pyproject.toml
-
-      - name: install ruff
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: pip install ruff
-        shell: bash
-
-      - name: ruff check
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: ruff check --output-format=github .
-        shell: bash
-
-      - name: ruff format
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: ruff format --check .
-        shell: bash
.github/workflows/python-tests.yml (deleted, 94 lines)

@@ -1,94 +0,0 @@
-# Runs python tests on a matrix of python versions and platforms.
-#
-# Checks for changes to python files before running the tests.
-# When manually triggered or called from another workflow, always runs the tests.
-
-name: 'python tests'
-
-on:
-  push:
-    branches:
-      - 'main'
-  pull_request:
-    types:
-      - 'ready_for_review'
-      - 'opened'
-      - 'synchronize'
-  merge_group:
-  workflow_dispatch:
-  workflow_call:
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
-  cancel-in-progress: true
-
-jobs:
-  matrix:
-    strategy:
-      matrix:
-        python-version:
-          - '3.10'
-          - '3.11'
-        platform:
-          - linux-cuda-11_7
-          - linux-rocm-5_2
-          - linux-cpu
-          - macos-default
-          - windows-cpu
-        include:
-          - platform: linux-cuda-11_7
-            os: ubuntu-22.04
-            github-env: $GITHUB_ENV
-          - platform: linux-rocm-5_2
-            os: ubuntu-22.04
-            extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
-            github-env: $GITHUB_ENV
-          - platform: linux-cpu
-            os: ubuntu-22.04
-            extra-index-url: 'https://download.pytorch.org/whl/cpu'
-            github-env: $GITHUB_ENV
-          - platform: macos-default
-            os: macOS-12
-            github-env: $GITHUB_ENV
-          - platform: windows-cpu
-            os: windows-2022
-            github-env: $env:GITHUB_ENV
-    name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
-    runs-on: ${{ matrix.os }}
-    timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
-    env:
-      PIP_USE_PEP517: '1'
-    steps:
-      - name: checkout
-        uses: actions/checkout@v4
-
-      - name: check for changed python files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
-        id: changed-files
-        uses: tj-actions/changed-files@v42
-        with:
-          files_yaml: |
-            python:
-              - 'pyproject.toml'
-              - 'invokeai/**'
-              - '!invokeai/frontend/web/**'
-              - 'tests/**'
-
-      - name: setup python
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python-version }}
-          cache: pip
-          cache-dependency-path: pyproject.toml
-
-      - name: install dependencies
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        env:
-          PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
-        run: >
-          pip3 install --editable=".[test]"
-
-      - name: run pytest
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
-        run: pytest
.github/workflows/release.yml (103 changes)

@@ -1,96 +1,103 @@
 # Main release workflow. Triggered on tag push or manual trigger.
 #
 # - Runs all code checks and tests
 # - Verifies the app version matches the tag version.
 # - Builds the installer and build, uploading them as artifacts.
 # - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
 #
 # See docs/RELEASE.md for more information on the release process.

-name: release
+name: Release

 on:
   push:
     tags:
       - 'v*'
   workflow_dispatch:
+    inputs:
+      skip_code_checks:
+        description: 'Skip code checks'
+        required: true
+        default: true
+        type: boolean

 jobs:
   check-version:
     runs-on: ubuntu-latest
     steps:
-      - name: checkout
-        uses: actions/checkout@v4
+      - uses: actions/checkout@v4

-      - name: check python version
-        uses: samuelcolvin/check-python-version@v4
+      - uses: samuelcolvin/check-python-version@v4
         id: check-python-version
         with:
           version_file_path: invokeai/version/invokeai_version.py

-  frontend-checks:
-    uses: ./.github/workflows/frontend-checks.yml
+  check-frontend:
+    if: github.event.inputs.skip_code_checks != 'true'
+    uses: ./.github/workflows/check-frontend.yml

-  frontend-tests:
-    uses: ./.github/workflows/frontend-tests.yml
+  check-python:
+    if: github.event.inputs.skip_code_checks != 'true'
+    uses: ./.github/workflows/check-python.yml

-  python-checks:
-    uses: ./.github/workflows/python-checks.yml
-
-  python-tests:
-    uses: ./.github/workflows/python-tests.yml
+  check-pytest:
+    if: github.event.inputs.skip_code_checks != 'true'
+    uses: ./.github/workflows/check-pytest.yml

   build:
-    uses: ./.github/workflows/build-installer.yml
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install python dependencies
+        uses: ./.github/actions/install-python-deps
+
+      - name: Install pypa/build
+        run: pip install --upgrade build
+
+      - name: Setup frontend
+        uses: ./.github/actions/install-frontend-deps
+
+      - name: Run create_installer.sh
+        id: create_installer
+        run: ./create_installer.sh --skip_frontend_checks
+        working-directory: installer
+
+      - name: Upload python distribution artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: dist
+          path: ${{ steps.create_installer.outputs.DIST_PATH }}
+
+      - name: Upload installer artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
+          path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}

   publish-testpypi:
     runs-on: ubuntu-latest
     timeout-minutes: 5 # expected run time: <1 min
-    needs:
-      [
-        check-version,
-        frontend-checks,
-        frontend-tests,
-        python-checks,
-        python-tests,
-        build,
-      ]
+    needs: [check-version, check-frontend, check-python, check-pytest, build]
+    if: github.event_name != 'workflow_dispatch'
     environment:
       name: testpypi
       url: https://test.pypi.org/p/invokeai
     steps:
-      - name: download distribution from build job
+      - name: Download distribution from build job
        uses: actions/download-artifact@v4
        with:
          name: dist
          path: dist/

-      - name: publish distribution to TestPyPI
+      - name: Publish distribution to TestPyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/

   publish-pypi:
     runs-on: ubuntu-latest
     timeout-minutes: 5 # expected run time: <1 min
-    needs:
-      [
-        check-version,
-        frontend-checks,
-        frontend-tests,
-        python-checks,
-        python-tests,
-        build,
-      ]
+    needs: [check-version, check-frontend, check-python, check-pytest, build]
+    if: github.event_name != 'workflow_dispatch'
     environment:
       name: pypi
       url: https://pypi.org/p/invokeai
     steps:
-      - name: download distribution from build job
+      - name: Download distribution from build job
        uses: actions/download-artifact@v4
        with:
          name: dist
          path: dist/

-      - name: publish distribution to PyPI
+      - name: Publish distribution to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
docs/RELEASE.md

@@ -23,13 +23,13 @@ It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if

 Run `make tag-release` to tag the current commit and kick off the workflow.

-The release may also be dispatched [manually].
+The release may also be run [manually].

 ### Workflow Jobs and Process

 The workflow consists of a number of concurrently-run jobs, and two final publish jobs.

-The publish jobs require manual approval and are only run if the other jobs succeed.
+The publish jobs run if the 5 concurrent jobs all succeed and if/when the publish jobs are approved.

 #### `check-version` Job

@@ -43,16 +43,17 @@ This job uses [samuelcolvin/check-python-version].

 #### Check and Test Jobs

-- **`python-tests`**: runs `pytest` on matrix of platforms
-- **`python-checks`**: runs `ruff` (format and lint)
-- **`frontend-tests`**: runs `vitest`
-- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
+This is our test suite.
+
+- **`check-pytest`**: runs `pytest` on matrix of platforms
+- **`check-python`**: runs `ruff` (format and lint)
+- **`check-frontend`**: runs `prettier` (format), `eslint` (lint), `madge` (circular refs) and `tsc` (static type check)

 > **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.

 > **TODO** We should add an end-to-end test job that generates an image.

-#### `build-installer` Job
+#### `build` Job

 This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:

@@ -61,7 +62,7 @@ This sets up both python and frontend dependencies and builds the python package

 #### Sanity Check & Smoke Test

-At this point, the release workflow pauses as the remaining publish jobs require approval.
+At this point, the release workflow pauses (the remaining jobs all require approval).

 A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.

@@ -69,7 +70,7 @@ A maintainer should go to the **Summary** tab of the workflow, download the inst

 #### PyPI Publish Jobs

-The publish jobs will run if any of the previous jobs fail.
+The publish jobs will skip if any of the previous jobs skip or fail.

 They use [GitHub environments], which are configured as [trusted publishers] on PyPI.

@@ -118,17 +119,13 @@ Once the release is published to PyPI, it's time to publish the GitHub release.

 > **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.

-## Manual Build
+## Manually Running the Release Workflow

-The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
+The release workflow can be run manually. This is useful to get an installer build and test it out without needing to push a tag.

-No checks are run, it just builds.
+When run this way, you'll see a **Skip code checks** checkbox. This allows the workflow to run without the three time-consuming code quality check jobs.

-## Manual Release
-
-The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
-
-This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
+The publish jobs will skip if the workflow was run manually.

 [InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
 [PyPI]: https://pypi.org/

@@ -139,4 +136,4 @@ This functionality is available as a fallback in case something goes wonky. Typi

 [GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
 [trusted publishers]: https://docs.pypi.org/trusted-publishers/
 [samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
-[manually]: #manual-release
+[manually]: #manually-running-the-release-workflow
docs/contributing/MODEL_MANAGER.md

@@ -32,6 +32,7 @@ model. These are the:
 Responsible for loading a model from disk
 into RAM and VRAM and getting it ready for inference.

+
 ## Location of the Code

 The four main services can be found in
@@ -62,21 +63,23 @@ provides the following fields:
 |----------------|-----------------|------------------|
 | `key` | str | Unique identifier for the model |
 | `name` | str | Name of the model (not unique) |
 | `model_type` | ModelType | The type of the model |
 | `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
 | `base_model` | BaseModelType | The base model that the model is compatible with |
 | `path` | str | Location of model on disk |
-| `hash` | str | Hash of the model |
+| `original_hash` | str | Hash of the model when it was first installed |
+| `current_hash` | str | Most recent hash of the model's contents |
 | `description` | str | Human-readable description of the model (optional) |
 | `source` | str | Model's source URL or repo id (optional) |

 The `key` is a unique 32-character random ID which was generated at
-install time. The `hash` field stores a hash of the model's
+install time. The `original_hash` field stores a hash of the model's
 contents at install time obtained by sampling several parts of the
 model's files using the `imohash` library. Over the course of the
 model's lifetime it may be transformed in various ways, such as
 changing its precision or converting it from a .safetensors to a
-diffusers model.
+diffusers model. When this happens, `original_hash` is unchanged, but
+`current_hash` is updated to indicate the current contents.
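To make the hashing behavior concrete, here is a minimal sketch using the `imohash` package named above. It is illustrative only — the file paths are made up and this is not InvokeAI's actual install code. `imohash.hashfile()` samples a few regions of the file rather than reading all of it, which is why it stays fast on multi-gigabyte checkpoints:

```
from imohash import hashfile

# At install time: record a sampling hash of the model file.
original_hash = hashfile('/opt/models/sushi.safetensors', hexdigest=True)

# Later, after the model is transformed (e.g. converted to fp16), only
# the current hash is recomputed; original_hash stays as first recorded.
current_hash = hashfile('/opt/models/sushi-fp16.safetensors', hexdigest=True)
```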
 `ModelType`, `ModelFormat` and `BaseModelType` are string enums that
 are defined in `invokeai.backend.model_manager.config`. They are also
@@ -91,6 +94,7 @@ The `path` field can be absolute or relative. If relative, it is taken
 to be relative to the `models_dir` setting in the user's
 `invokeai.yaml` file.

+
 ### CheckpointConfig

 This adds support for checkpoint configurations, and adds the
@@ -170,7 +174,7 @@ store = context.services.model_manager.store
 or from elsewhere in the code by accessing
 `ApiDependencies.invoker.services.model_manager.store`.

 ### Creating a `ModelRecordService`

 To create a new `ModelRecordService` database or open an existing one,
 you can directly create either a `ModelRecordServiceSQL` or a
@@ -213,27 +217,27 @@ for use in the InvokeAI web server. Its signature is:
 ```
 def open(
     cls,
     config: InvokeAIAppConfig,
     conn: Optional[sqlite3.Connection] = None,
     lock: Optional[threading.Lock] = None
 ) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
 ```

 The way it works is as follows:

 1. Retrieve the value of the `model_config_db` option from the user's
    `invokeai.yaml` config file.
 2. If `model_config_db` is `auto` (the default), then:
-   * Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
-     opened on the passed connection and lock.
-   * Open up a new connection to `databases/invokeai.db` if `conn`
+   - Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
+     opened on the passed connection and lock.
+   - Open up a new connection to `databases/invokeai.db` if `conn`
      and/or `lock` are missing (see note below).
 3. If `model_config_db` is a Path, then use `from_db_file`
    to return the appropriate type of ModelRecordService.
 4. If `model_config_db` is None, then retrieve the legacy
    `conf_path` option from `invokeai.yaml` and use the Path
    indicated there. This will default to `configs/models.yaml`.

 So a typical startup pattern would be:

 ```
@@ -251,7 +255,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)

 Configurations can be retrieved in several ways.

-#### get_model(key) -> AnyModelConfig
+#### get_model(key) -> AnyModelConfig:

 The basic functionality is to call the record store object's
 `get_model()` method with the desired model's unique key. It returns
@@ -268,28 +272,28 @@ print(model_conf.path)
 If the key is unrecognized, this call raises an
 `UnknownModelException`.

-#### exists(key) -> AnyModelConfig
+#### exists(key) -> AnyModelConfig:

 Returns True if a model with the given key exists in the database.

-#### search_by_path(path) -> AnyModelConfig
+#### search_by_path(path) -> AnyModelConfig:

 Returns the configuration of the model whose path is `path`. The path
 is matched using a simple string comparison and won't correctly match
 models referred to by different paths (e.g. using symbolic links).

-#### search_by_name(name, base, type) -> List[AnyModelConfig]
+#### search_by_name(name, base, type) -> List[AnyModelConfig]:

 This method searches for models that match some combination of `name`,
 `BaseType` and `ModelType`. Calling without any arguments will return
 all the models in the database.
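A quick sketch of `search_by_name()` as described above. The argument names follow this document, but the enum spellings are assumptions, so treat it as illustrative rather than authoritative:

```
from invokeai.backend.model_manager.config import BaseModelType, ModelType

# All main (pipeline) models compatible with SD-1.x:
sd1_mains = store.search_by_name(base=BaseModelType('sd-1'), type=ModelType('main'))

# No arguments: every model in the database (same as all_models()).
everything = store.search_by_name()
```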
-#### all_models() -> List[AnyModelConfig]
+#### all_models() -> List[AnyModelConfig]:

 Return all the model configs in the database. Exactly equivalent to
 calling `search_by_name()` with no arguments.

-#### search_by_tag(tags) -> List[AnyModelConfig]
+#### search_by_tag(tags) -> List[AnyModelConfig]:

 `tags` is a list of strings. This method returns a list of model
 configs that contain all of the given tags. Examples:
@@ -308,11 +312,11 @@ commercializable_models = [x for x in store.all_models() \
 if x.license.contains('allowCommercialUse=Sell')]
 ```

-#### version() -> str
+#### version() -> str:

 Returns the version of the database, currently at `3.2`

-#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
+#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:

 This method exists to ease the transition from the previous version of
 the model manager, in which `get_model()` took the three arguments
@@ -333,7 +337,7 @@ model and pass its key to `get_model()`.
 Several methods allow you to create and update stored model config
 records.

-#### add_model(key, config) -> AnyModelConfig
+#### add_model(key, config) -> AnyModelConfig:

 Given a key and a configuration, this will add the model's
 configuration record to the database. `config` can either be a subclass of
@@ -348,7 +352,7 @@ model with the same key is already in the database, or an
 `InvalidModelConfigException` if a dict was passed and Pydantic
 experienced a parse or validation error.
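A hedged sketch of `add_model()` with a plain dict, per the description above. The made-up key and the `base_model`/`model_type` enum spellings are assumptions; Pydantic raises `InvalidModelConfigException` if the dict does not validate:

```
config = {
    'path': 'sd-1/main/sushi_v1',  # relative to models_dir
    'name': 'sushi_v1',
    'base_model': 'sd-1',          # assumed enum spelling
    'model_type': 'main',          # assumed enum spelling
    'model_format': 'diffusers',
}
store.add_model('0123456789abcdef0123456789abcdef', config)  # 32-char key
```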
-### update_model(key, config) -> AnyModelConfig
+### update_model(key, config) -> AnyModelConfig:

 Given a key and a configuration, this will update the model
 configuration record in the database. `config` can be either a
@@ -366,31 +370,31 @@ The `ModelInstallService` class implements the
 shop for all your model install needs. It provides the following
 functionality:

-* Registering a model config record for a model already located on the
+- Registering a model config record for a model already located on the
   local filesystem, without moving it or changing its path.

-* Installing a model already located on the local filesystem, by
+- Installing a model already located on the local filesystem, by
   moving it into the InvokeAI root directory under the
   `models` folder (or wherever config parameter `models_dir`
   specifies).

-* Probing of models to determine their type, base type and other key
+- Probing of models to determine their type, base type and other key
   information.

-* Interface with the InvokeAI event bus to provide status updates on
+- Interface with the InvokeAI event bus to provide status updates on
   the download, installation and registration process.

-* Downloading a model from an arbitrary URL and installing it in
+- Downloading a model from an arbitrary URL and installing it in
   `models_dir`.

-* Special handling for Civitai model URLs which allow the user to
+- Special handling for Civitai model URLs which allow the user to
   paste in a model page's URL or download link

-* Special handling for HuggingFace repo_ids to recursively download
+- Special handling for HuggingFace repo_ids to recursively download
   the contents of the repository, paying attention to alternative
   variants such as fp16.

-* Saving tags and other metadata about the model into the invokeai database
+- Saving tags and other metadata about the model into the invokeai database
   when fetching from a repo that provides that type of information,
   (currently only Civitai and HuggingFace).
@@ -423,8 +427,8 @@ queue.start()

 installer = ModelInstallService(app_config=config,
                                 record_store=record_store,
                                 download_queue=queue
                                 )
 installer.start()
 ```
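As a follow-up to the initialization snippet, a minimal usage sketch. It reuses the `installer` object from above together with the `heuristic_import()` and `wait_for_installs()` methods documented below; the URL is the illustrative one used elsewhere in this document:

```
# Queue an install from a plain source string (accepted forms are
# described below), then block until the queue drains.
job = installer.heuristic_import('https://civitai.com/models/58390/detail-tweaker-lora-lora')
installer.wait_for_installs(timeout=120)

if job.complete:
    print(f'installed with key {job.config_out.key}')
elif job.errored:
    print(f'failed: {job.error_type}')
```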
@@ -439,6 +443,7 @@ required parameters:
 | `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
 | `session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |

+
 Once initialized, the installer will provide the following methods:

 #### install_job = installer.heuristic_import(source, [config], [access_token])
@@ -452,15 +457,15 @@ The `source` is a string that can be any of these forms
 1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
 2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
 3. A HuggingFace repo_id with any of the following formats:
-   * `model/name` -- entire model
-   * `model/name:fp32` -- entire model, using the fp32 variant
-   * `model/name:fp16:vae` -- vae submodel, using the fp16 variant
-   * `model/name::vae` -- vae submodel, using default precision
-   * `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
-   * `model/name::path/to/model.safetensors` -- an individual model file, default variant
+   - `model/name` -- entire model
+   - `model/name:fp32` -- entire model, using the fp32 variant
+   - `model/name:fp16:vae` -- vae submodel, using the fp16 variant
+   - `model/name::vae` -- vae submodel, using default precision
+   - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
+   - `model/name::path/to/model.safetensors` -- an individual model file, default variant

 Note that by specifying a relative path to the top of the HuggingFace
 repo, you can download and install arbitrary model files.

 The variant, if not provided, will be automatically filled in with
 `fp32` if the user has requested full precision, and `fp16`
@@ -486,9 +491,9 @@ following illustrates basic usage:

 ```
 from invokeai.app.services.model_install import (
     LocalModelSource,
     HFModelSource,
     URLModelSource,
 )

 source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
@@ -508,13 +513,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:
 source2job = installer.wait_for_installs(timeout=120)
 for source in sources:
     job = source2job[source]
     if job.complete:
         model_config = job.config_out
         model_key = model_config.key
         print(f"{source} installed as {model_key}")
     elif job.errored:
         print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
 ```

 As shown here, the `import_model()` method accepts a variety of
@@ -523,7 +528,7 @@ HuggingFace repo_ids with and without a subfolder designation,
 Civitai model URLs and arbitrary URLs that point to checkpoint files
 (but not to folders).

 Each call to `import_model()` returns a `ModelInstallJob` job,
 an object which tracks the progress of the install.

 If a remote model is requested, the model's files are downloaded in
@@ -550,7 +555,7 @@ The full list of arguments to `import_model()` is as follows:
 | `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |

 The next few sections describe the various types of ModelSource that
 can be passed to `import_model()`.

 `config` can be used to override all or a portion of the configuration
 attributes returned by the model prober. See the section below for
@@ -561,6 +566,7 @@ details.

 This is used for a model that is located on a locally-accessible Posix
 filesystem, such as a local disk or networked fileshare.

+
 | **Argument** | **Type** | **Default** | **Description** |
 |------------------|------------------------------|-------------|-------------------------------------------|
 | `path` | str | Path | None | Path to the model file or directory |
@@ -619,6 +625,7 @@ HuggingFace has the most complicated `ModelSource` structure:
 | `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
 | `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |

+
 The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.

 The `variant` is one of the various diffusers formats that HuggingFace
@@ -654,6 +661,7 @@ in. To download these files, you must provide an
 `HfFolder.get_token()` will be called to fill it in with the cached
 one.

+
 #### Monitoring the install job process

 When you create an install job with `import_model()`, it launches the
@@ -667,13 +675,14 @@ The `ModelInstallJob` class has the following structure:
 | `id` | `int` | Integer ID for this job |
 | `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
 | `config_in` | `dict` | Overriding configuration values provided by the caller |
 | `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
 | `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
 | `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
 | `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
 | `error_type` | `str` | Name of the exception that led to an error status |
 | `error` | `str` | Traceback of the error |

+
 If the `event_bus` argument was provided, events will also be
 broadcast to the InvokeAI event bus. The events will appear on the bus
 as an event of type `EventServiceBase.model_event`, a timestamp and
@@ -693,13 +702,14 @@ following keys:
 | `total_bytes` | int | Total size of all the files that make up the model |
 | `parts` | List[Dict]| Information on the progress of the individual files that make up the model |

+
 The parts is a list of dictionaries that give information on each of
 the component pieces of the download. The dictionary's keys are
 `source`, `local_path`, `bytes` and `total_bytes`, and correspond to
 the like-named keys in the main event.

 Note that downloading events will not be issued for local models, and
-that downloading events occur _before_ the running event.
+that downloading events occur *before* the running event.
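For illustration, here is a sketch of a progress handler over the `model_install_downloading` payload described above. The key names follow the tables in this document, but the wiring that delivers the payload from the event bus is application-specific and omitted:

```
def on_model_install_downloading(payload: dict) -> None:
    # Overall progress for the install source.
    done, total = payload['bytes'], payload['total_bytes'] or 1
    print(f"{payload['source']}: {100 * done / total:.1f}%")
    # Per-file progress for each part of a multi-file model.
    for part in payload['parts']:
        print(f"  {part['local_path']}: {part['bytes']}/{part['total_bytes']} bytes")
```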
 ##### `model_install_running`

@@ -742,13 +752,14 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
 and `cancelled`, as well as `in_terminal_state`. The last will return
 True if the job is in the complete, errored or cancelled states.

+
 #### Model configuration and probing

 The install service uses the `invokeai.backend.model_manager.probe`
 module during import to determine the model's type, base type, and
 other configuration parameters. Among other things, it assigns a
 default name and description for the model based on probed
 fields.

 When downloading remote models is implemented, additional
 configuration information, such as list of trigger terms, will be
@@ -763,11 +774,11 @@ attributes. Here is an example of setting the

 ```
 install_job = installer.import_model(
     source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1',variant='fp32'),
     config=dict(
         prediction_type=SchedulerPredictionType('v_prediction'),
         name='stable diffusion 2 base model',
     )
 )
 ```

 ### Other installer methods
@@ -851,6 +862,7 @@ This method is similar to `unregister()`, but also unconditionally
 deletes the corresponding model weights file(s), regardless of whether
 they are inside or outside the InvokeAI models hierarchy.

+
 #### path = installer.download_and_cache(remote_source, [access_token], [timeout])

 This utility routine will download the model file located at source,
@@ -941,7 +953,7 @@ following fields:

 When you create a job, you can assign it a `priority`. If multiple
 jobs are queued, the job with the lowest priority runs first. (Don't
 blame me! The Unix developers came up with this convention.)

 Every job has a `source` and a `destination`. `source` is a string in
 the base class, but subclasses redefine it more specifically.
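To make the priority convention concrete, an illustrative sketch. It assumes the `DownloadJobRemoteSource` class and `submit_download_job()` method shown later in this document, plus a `priority` value assigned at creation time as described above; the URLs are placeholders:

```
urgent = DownloadJobRemoteSource(
    source='https://example.com/a.safetensors',
    destination='/tmp/models/',
    priority=0,   # lowest value runs first, nice(1)-style
)
background = DownloadJobRemoteSource(
    source='https://example.com/b.safetensors',
    destination='/tmp/models/',
    priority=10,
)
queue.submit_download_job(background, start=True)
queue.submit_download_job(urgent, start=True)  # still serviced first
```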
@ -962,7 +974,7 @@ is in its lifecycle. Values are defined in the string enum
|
||||
`DownloadJobStatus`, a symbol available from
|
||||
`invokeai.app.services.download_manager`. Possible values are:
|
||||
|
||||
| **Value** | **String Value** | **Description** |
|
||||
| **Value** | **String Value** | ** Description ** |
|
||||
|--------------|---------------------|-------------------|
|
||||
| `IDLE` | idle | Job created, but not submitted to the queue |
|
||||
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
|
||||
@ -979,7 +991,7 @@ debugging and performance testing.
|
||||
|
||||
In case of an error, the Exception that caused the error will be
|
||||
placed in the `error` field, and the job's status will be set to
|
||||
`DownloadJobStatus.ERROR`.
|
||||
`DownloadJobStatus.ERROR`.
|
||||
|
||||
After an error occurs, any partially downloaded files will be deleted
|
||||
from disk, unless `preserve_partial_downloads` was set to True at job
|
||||
@ -1028,11 +1040,11 @@ While a job is being downloaded, the queue will emit events at
|
||||
periodic intervals. A typical series of events during a successful
|
||||
download session will look like this:
|
||||
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* running
|
||||
* completed
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- running
|
||||
- completed
|
||||
|
||||
There will be a single enqueued event, followed by one or more running
|
||||
events, and finally one `completed`, `error` or `cancelled`
|
||||
@ -1041,12 +1053,12 @@ events.
It is possible for a caller to pause download temporarily, in which
case the events may look something like this:

- enqueued
- running
- running
- paused
- running
- completed

The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress
@ -1108,11 +1120,11 @@ A typical initialization sequence will look like:
```
from invokeai.app.services.download_manager import DownloadQueueService

def log_download_event(job: DownloadJobBase):
    logger.info(f'job={job.id}: status={job.status}')

queue = DownloadQueueService(
    event_handlers=[log_download_event]
)
```

Event handlers can be provided to the queue at initialization time as
@ -1143,9 +1155,9 @@ To use the former method, follow this example:
```
job = DownloadJobRemoteSource(
    source='http://www.civitai.com/models/13456',
    destination='/tmp/models/',
    event_handlers=[my_handler1, my_handler2],  # if desired
)
queue.submit_download_job(job, start=True)
```
@ -1160,13 +1172,13 @@ To have the queue create the job for you, follow this example instead:
```
job = queue.create_download_job(
    source='http://www.civitai.com/models/13456',
    destdir='/tmp/models/',
    filename='my_model.safetensors',
    event_handlers=[my_handler1, my_handler2],  # if desired
    start=True,
)
```

The `filename` argument forces the downloader to use the specified
name for the file rather than the name provided by the remote source,
and is equivalent to manually specifying a destination of
@ -1175,6 +1187,7 @@ and is equivalent to manually specifying a destination of
Here is the full list of arguments that can be provided to
`create_download_job()`:

| **Argument**     | **Type**                     | **Default** | **Description**                           |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source`         | Union[str, Path, AnyHttpUrl] |             | Download remote or local source           |
@ -1187,7 +1200,7 @@ Here is the full list of arguments that can be provided to

Internally, `create_download_job()` contains a small piece of logic
that looks at the type of the source and selects the right subclass of
`DownloadJobBase` to create and enqueue.

**TODO**: move this logic into its own method for overriding in
subclasses.
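The dispatch it performs might look roughly like the sketch below (illustrative only; `DownloadJobPath` is a hypothetical name for the local-file job class):

```
from pathlib import Path

def select_job_class(source):
    # Local paths get a file-copy job; everything else is treated as remote.
    if isinstance(source, Path) or Path(str(source)).exists():
        return DownloadJobPath        # hypothetical local-source job class
    return DownloadJobRemoteSource    # remote job class shown earlier
```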
@ -1262,7 +1275,7 @@ for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.

### Example Usage

```
from invokeai.backend.model_manager.metadata import (
@ -1315,6 +1328,7 @@ This is the common base class for metadata:
| `author`   | str      | Model's author |
| `tags`     | Set[str] | Model tags |

Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
@ -1334,6 +1348,7 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime   | Date of last commit of this model to the repo |
| `files`        | List[Path] | List of the files in the model repo |

#### `CivitaiMetadata`

This descends from `ModelMetadataBase` and adds the following fields:
@ -1400,6 +1415,7 @@ testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:

#### HuggingFaceMetadataFetch

This overrides its base class `from_json()` method to return a
@ -1418,12 +1434,13 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
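A hedged sketch of the Civitai fetcher in use (the version id is a placeholder; the import path follows the package named earlier in this document):

```
from invokeai.backend.model_manager.metadata import CivitaiMetadataFetch

fetcher = CivitaiMetadataFetch()      # elsewhere in this diff a requests session is passed in
metadata = fetcher.from_id('123456')  # placeholder Civitai version id
print(metadata.name, metadata.tags)
```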

### Metadata Storage

The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
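A minimal sketch of the store in use (method names follow the abstract interface shown later in this diff; `db` is assumed to be an existing database handle):

```
store = ModelMetadataStoreSQL(db=db)
store.add_metadata(model_key, metadata)  # key must already exist in model_config
fetched = store.get_metadata(model_key)
print(store.list_tags())                 # all tags, broken out for searching
```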

When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
@ -1518,16 +1535,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist

config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
    app_config=config,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
    registry=ModelLoaderRegistry
)
```
@ -1550,6 +1567,7 @@ The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as
the in-memory loaded model:

| **Attribute Name** | **Type**       | **Description**  |
|--------------------|----------------|------------------|
| `config`           | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
@ -1563,6 +1581,7 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.

`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
@ -1571,14 +1590,14 @@ the model. Use it like this:

```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
    image = vae.decode(latents)[0]
```

`get_model_by_key()` may raise any of the following exceptions:

- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `NotImplementedException` -- the loader doesn't know how to load this type of model
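For example, a sketch of guarding a lookup against these (the exceptions' import location is assumed):

```
try:
    model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
except UnknownModelException:
    print('no such key in the database')
except ModelNotFoundException:
    print('key exists, but the model file is missing on disk')
except NotImplementedException:
    print('no loader registered for this type of model')
```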

### Emitting model loading events

@ -1590,15 +1609,15 @@ following payload:

```
payload=dict(
    queue_id=queue_id,
    queue_item_id=queue_item_id,
    queue_batch_id=queue_batch_id,
    graph_execution_state_id=graph_execution_state_id,
    model_key=model_key,
    submodel_type=submodel,
    hash=model_info.hash,
    location=str(model_info.location),
    precision=str(model_info.precision),
)
|

@ -1705,7 +1724,6 @@ object, or in `context.services.model_manager` from within an
invocation.

In the examples below, we have retrieved the manager using:

```
mm = ApiDependencies.invoker.services.model_manager
```

@ -19,8 +19,6 @@ their descriptions.
| Conditioning Primitive | A conditioning tensor primitive value |
| Content Shuffle Processor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes |
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
| Denoise Latents | Denoises noisy latents to decodable images |
| Divide Integers | Divides two numbers |
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |

@ -9,7 +9,7 @@ set INVOKEAI_ROOT=.
:start
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Run textual inversion training
echo 2. Invoke Model Training
echo 3. Merge models (diffusers type only)
echo 4. Download and install models
echo 5. Change InvokeAI startup options
@ -25,8 +25,7 @@ IF /I "%choice%" == "1" (
    echo Starting the InvokeAI browser-based UI..
    python .venv\Scripts\invokeai-web.exe %*
) ELSE IF /I "%choice%" == "2" (
    echo Starting textual inversion training..
    python .venv\Scripts\invokeai-ti.exe --gui
    echo To use Invoke Training for LoRA, TI, and more - Visit https://github.com/invoke-ai/invoke-training
) ELSE IF /I "%choice%" == "3" (
    echo Starting model merging script..
    python .venv\Scripts\invokeai-merge.exe --gui

@ -59,8 +59,7 @@ do_choice() {
        ;;
    2)
        clear
        printf "Textual inversion training\n"
        invokeai-ti --gui $PARAMS
        printf "To use Invoke Training for LoRA, TI, and more - Visit https://github.com/invoke-ai/invoke-training\n"
        ;;
    3)
        clear
@ -118,7 +117,7 @@ do_choice() {
do_dialog() {
    options=(
        1 "Generate images with a browser-based interface"
        2 "Textual inversion training"
        2 "Run Invoke Training"
        3 "Merge models (diffusers type only)"
        4 "Download and install models"
        5 "Change InvokeAI startup options"
@ -151,7 +150,7 @@ do_line_input() {
    printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
    printf "What would you like to do?\n"
    printf "1: Generate images using the browser-based interface\n"
    printf "2: Run textual inversion training\n"
    printf "2: Run Invoke Training\n"
    printf "3: Merge models (diffusers type only)\n"
    printf "4: Download and install models\n"
    printf "5: Change InvokeAI startup options\n"

@ -26,6 +26,7 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
@ -92,9 +93,10 @@ class ApiDependencies:
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db),
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
|
@ -3,7 +3,9 @@
|
||||
|
||||
import pathlib
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
@ -13,10 +15,13 @@ from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -25,6 +30,8 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -40,6 +47,15 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
@ -50,16 +66,19 @@ example_model_config = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config_path": "string",
|
||||
"config": "string",
|
||||
"key": "string",
|
||||
"hash": "string",
|
||||
"original_hash": "string",
|
||||
"current_hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"converted_at": 0,
|
||||
"last_modified": 0,
|
||||
"vae": "string",
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
"upcast_attention": False,
|
||||
"ztsnr_training": False,
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
@ -68,12 +87,50 @@ example_model_input = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config_path": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"config": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
example_model_metadata = {
|
||||
"name": "ip_adapter_sd_image_encoder",
|
||||
"author": "InvokeAI",
|
||||
"tags": [
|
||||
"transformers",
|
||||
"safetensors",
|
||||
"clip_vision_model",
|
||||
"endpoints_compatible",
|
||||
"region:us",
|
||||
"has_space",
|
||||
"license:apache-2.0",
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
||||
"size": 628,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
||||
"size": 560,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
||||
"size": 2528373448,
|
||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
||||
},
|
||||
],
|
||||
"type": "huggingface",
|
||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
||||
"tag_dict": {"license": "apache-2.0"},
|
||||
"last_modified": "2023-09-23T17:33:25Z",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
# ROUTES
|
||||
##############################################################################
|
||||
@ -153,16 +210,48 @@ async def get_model_record(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# @model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
# async def list_model_summary(
|
||||
# page: int = Query(default=0, description="The page to get"),
|
||||
# per_page: int = Query(default=10, description="The number of models per page"),
|
||||
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
# ) -> PaginatedResults[ModelSummary]:
|
||||
# """Gets a page of model summary data."""
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
# return results
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model metadata was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
|
||||
|
||||
class FoundModel(BaseModel):
|
||||
@ -234,6 +323,19 @@ async def scan_for_models(
|
||||
return scan_results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
@ -250,13 +352,15 @@ async def scan_for_models(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -268,14 +372,14 @@ async def update_model_record(
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="delete_model",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model(
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@ -296,39 +400,42 @@ async def delete_model(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# @model_manager_router.post(
|
||||
# "/i/",
|
||||
# operation_id="add_model_record",
|
||||
# responses={
|
||||
# 201: {
|
||||
# "description": "The model added successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
# 415: {"description": "Unrecognized file/folder format"},
|
||||
# },
|
||||
# status_code=201,
|
||||
# )
|
||||
# async def add_model_record(
|
||||
# config: Annotated[
|
||||
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
# ],
|
||||
# ) -> AnyModelConfig:
|
||||
# """Add a model using the configuration information appropriate for its type."""
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# try:
|
||||
# record_store.add_model(config)
|
||||
# except DuplicateModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=409, detail=str(e))
|
||||
# except InvalidModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=415)
|
||||
@model_manager_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# # now fetch it out
|
||||
# result: AnyModelConfig = record_store.get_model(config.key)
|
||||
# return result
|
||||
# now fetch it out
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
@ -344,7 +451,6 @@ async def delete_model(
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
@ -387,7 +493,6 @@ async def install_model(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=bool(inplace),
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
@ -403,10 +508,10 @@ async def install_model(
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_installs() -> List[ModelInstallJob]:
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@ -420,8 +525,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
|
||||
Once completed, information about the model such as its size, base
|
||||
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
|
||||
information on individual files can be retrieved from `download_parts`.
|
||||
model, type, and metadata can be retrieved from the `config_out`
|
||||
field. For multi-file models such as diffusers, information on individual files
|
||||
can be retrieved from `download_parts`.
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
@ -430,7 +536,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install/{id}",
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
@ -450,7 +556,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/install/{id}",
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
@ -468,8 +574,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/install",
|
||||
@model_manager_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
@ -548,8 +654,7 @@ async def convert_model(
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
original_name = model_config.name
|
||||
model_config.name = f"{original_name}.DELETE"
|
||||
changes = ModelRecordChanges(name=model_config.name)
|
||||
store.update_model(key, changes=changes)
|
||||
store.update_model(key, config=model_config)
|
||||
|
||||
# install the diffusers
|
||||
try:
|
||||
@ -558,7 +663,7 @@ async def convert_model(
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"hash": model_config.hash,
|
||||
"original_hash": model_config.original_hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
@ -566,6 +671,10 @@ async def convert_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# get the original metadata
|
||||
if orig_metadata := store.get_metadata(key):
|
||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
@ -577,66 +686,66 @@ async def convert_model(
|
||||
return new_config
|
||||
|
||||
|
||||
# @model_manager_router.put(
|
||||
# "/merge",
|
||||
# operation_id="merge",
|
||||
# responses={
|
||||
# 200: {
|
||||
# "description": "Model converted successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 400: {"description": "Bad request"},
|
||||
# 404: {"description": "Model not found"},
|
||||
# 409: {"description": "There is already a model registered at this location"},
|
||||
# },
|
||||
# )
|
||||
# async def merge(
|
||||
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
# force: bool = Body(
|
||||
# description="Force merging of models created with different versions of diffusers",
|
||||
# default=False,
|
||||
# ),
|
||||
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
# merge_dest_directory: Optional[str] = Body(
|
||||
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
# default=None,
|
||||
# ),
|
||||
# ) -> AnyModelConfig:
|
||||
# """
|
||||
# Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
|
||||
# ```
|
||||
# Argument Description [default]
|
||||
# -------- ----------------------
|
||||
# keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
# merged_model_name Name for the merged model [Concat model names]
|
||||
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
# merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
# ```
|
||||
# """
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# try:
|
||||
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
# installer = ApiDependencies.invoker.services.model_manager.install
|
||||
# merger = ModelMerger(installer)
|
||||
# model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
# response = merger.merge_diffusion_models_and_save(
|
||||
# model_keys=keys,
|
||||
# merged_model_name=merged_model_name or "+".join(model_names),
|
||||
# alpha=alpha,
|
||||
# interp=interp,
|
||||
# force=force,
|
||||
# merge_dest_directory=dest,
|
||||
# )
|
||||
# except UnknownModelException:
|
||||
# raise HTTPException(
|
||||
# status_code=404,
|
||||
# detail=f"One or more of the models '{keys}' not found",
|
||||
# )
|
||||
# except ValueError as e:
|
||||
# raise HTTPException(status_code=400, detail=str(e))
|
||||
# return response
|
||||
@model_manager_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
|
||||
```
|
||||
Argument Description [default]
|
||||
-------- ----------------------
|
||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name Name for the merged model [Concat model names]
|
||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -173,16 +173,6 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("gradient_mask_output")
|
||||
class GradientMaskOutput(BaseInvocationOutput):
|
||||
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
expanded_mask_area: ImageField = OutputField(
|
||||
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
@ -203,42 +193,38 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.edge_radius > 0:
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
# redistribute blur so that the edges are 0 and blur out to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||
else:
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
# multiply original mask to force actually masked regions to 0
|
||||
blur_tensor = mask_tensor * blur_tensor
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=None,
|
||||
gradient=True,
|
||||
)
|
||||
|
||||
|
||||
@ -374,6 +360,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
@ -383,6 +370,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
@ -789,7 +777,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
|
@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.VAE,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VAE,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VAE,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -256,7 +256,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
|
@ -18,9 +18,10 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
@ -150,13 +151,6 @@ ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||
]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
CivitaiModelSource: ModelSourceType.CivitAI,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
@ -266,6 +260,7 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@ -352,7 +347,6 @@ class ModelInstallServiceBase(ABC):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
@ -398,7 +392,7 @@ class ModelInstallServiceBase(ABC):
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, and/or `image_size`.
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
|
@ -7,6 +7,7 @@ import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@ -20,15 +21,11 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@ -38,14 +35,12 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
@ -95,6 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -143,7 +139,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
@ -153,11 +148,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@ -183,14 +177,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
@ -285,7 +278,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info("Model installer (re)initialized")
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
||||
self._models_installed.clear()
|
||||
@ -379,18 +372,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
||||
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@ -476,7 +466,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
self.record_store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
@ -528,14 +518,22 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or {}
|
||||
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
key = self._create_key()
|
||||
if config and not config.get("key", None):
|
||||
config["key"] = key
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
@ -545,11 +543,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
if hasattr(info, "config"):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info)
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
            info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
        self.record_store.add_model(info.key, info)
        return info.key

    def _next_id(self) -> int:
@ -570,15 +568,13 @@ class ModelInstallService(ModelInstallServiceBase):
            source=source,
            config_in=config or {},
            local_path=Path(source.path),
            inplace=source.inplace or False,
            inplace=source.inplace,
        )

    def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        if not source.access_token:
            self._logger.info("No Civitai access token provided; some models may not be downloadable.")
        metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
            str(source.version_id)
        )
        metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
        assert isinstance(metadata, ModelMetadataWithFiles)
        remote_files = metadata.download_urls(session=self._session)
        return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
@ -606,17 +602,15 @@ class ModelInstallService(ModelInstallServiceBase):

    def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        # URLs from Civitai or HuggingFace will be handled specially
        url_patterns = {
            r"^https?://civitai.com/": CivitaiMetadataFetch,
            r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
        }
        metadata = None
        fetcher = None
        try:
            fetcher = self.get_fetcher_from_url(str(source.url))
        except ValueError:
            pass
        kwargs: dict[str, Any] = {"session": self._session}
        if fetcher is CivitaiMetadataFetch:
            kwargs["api_key"] = self._app_config.get_config().civitai_api_key
        if fetcher is not None:
            metadata = fetcher(**kwargs).from_url(source.url)
        for pattern, fetcher in url_patterns.items():
            if re.match(pattern, str(source.url), re.IGNORECASE):
                metadata = fetcher(self._session).from_url(source.url)
                break
        self._logger.debug(f"metadata={metadata}")
        if metadata and isinstance(metadata, ModelMetadataWithFiles):
            remote_files = metadata.download_urls(session=self._session)
@ -631,7 +625,7 @@ class ModelInstallService(ModelInstallServiceBase):

    def _import_remote_model(
        self,
        source: HFModelSource | CivitaiModelSource | URLModelSource,
        source: ModelSource,
        remote_files: List[RemoteModelFile],
        metadata: Optional[AnyModelRepoMetadata],
        config: Optional[Dict[str, Any]],
@ -659,7 +653,7 @@ class ModelInstallService(ModelInstallServiceBase):
        # In the event that there is a subfolder specified in the source,
        # we need to remove it from the destination path in order to avoid
        # creating unwanted subfolders
        if isinstance(source, HFModelSource) and source.subfolder:
        if hasattr(source, "subfolder") and source.subfolder:
            root = Path(remote_files[0].path.parts[0])
            subfolder = root / source.subfolder
        else:
@ -846,11 +840,3 @@ class ModelInstallService(ModelInstallServiceBase):
            self._logger.info(f"{job.source}: model installation was cancelled")
            if self._event_bus:
                self._event_bus.emit_model_install_cancelled(str(job.source))

    @staticmethod
    def get_fetcher_from_url(url: str):
        if re.match(r"^https?://civitai.com/", url.lower()):
            return CivitaiMetadataFetch
        elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
            return HuggingFaceMetadataFetch
        raise ValueError(f"Unsupported model source: '{url}'")
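A minimal usage sketch, not part of this diff, showing how the static get_fetcher_from_url() dispatcher above behaves; the import path is an assumption based on this service's package location:

# Hedged sketch: URL-to-fetcher dispatch as implemented above.
from invokeai.app.services.model_install import ModelInstallService

for url in (
    "https://civitai.com/models/12345",
    "https://huggingface.co/stabilityai/sd-vae-ft-mse",
    "https://example.com/some-model.safetensors",
):
    try:
        fetcher = ModelInstallService.get_fetcher_from_url(url)
        print(f"{url} -> {fetcher.__name__}")
    except ValueError as exc:
        print(f"{url} -> {exc}")  # unsupported sources raise ValueError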
9
invokeai/app/services/model_metadata/__init__.py
Normal file
@ -0,0 +1,9 @@
"""Init file for ModelMetadataStoreService module."""

from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL

__all__ = [
    "ModelMetadataStoreBase",
    "ModelMetadataStoreSQL",
]
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
@ -0,0 +1,65 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""

from abc import ABC, abstractmethod
from typing import List, Set, Tuple

from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata


class ModelMetadataStoreBase(ABC):
    """Store, search and fetch model metadata retrieved from remote repositories."""

    @abstractmethod
    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        """

    @abstractmethod
    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """Retrieve the ModelRepoMetadata corresponding to model key."""

    @abstractmethod
    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
        """Dump out all the metadata."""

    @abstractmethod
    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        """

    @abstractmethod
    def list_tags(self) -> Set[str]:
        """Return all tags in the tags table."""

    @abstractmethod
    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """Return the keys of models containing all of the listed tags."""

    @abstractmethod
    def search_by_author(self, author: str) -> Set[str]:
        """Return the keys of models authored by the indicated author."""

    @abstractmethod
    def search_by_name(self, name: str) -> Set[str]:
        """
        Return the keys of models with the indicated name.

        Note that this is the name of the model given to it by
        the remote source. The user may have changed the local
        name. The local name will be located in the model config
        record object.
        """
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
@ -0,0 +1,222 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""

import sqlite3
from typing import List, Optional, Set, Tuple

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase

from .metadata_store_base import ModelMetadataStoreBase


class ModelMetadataStoreSQL(ModelMetadataStoreBase):
    """Store, search and fetch model metadata retrieved from remote repositories."""

    def __init__(self, db: SqliteDatabase):
        """
        Initialize a new object from preexisting sqlite3 connection and threading lock objects.

        :param conn: sqlite3 connection object
        :param lock: threading Lock object
        """
        super().__init__()
        self._db = db
        self._cursor = self._db.conn.cursor()

    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        """
        json_serialized = metadata.model_dump_json()
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO model_metadata(
                       id,
                       metadata
                    )
                    VALUES (?,?);
                    """,
                    (
                        model_key,
                        json_serialized,
                    ),
                )
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
                self._db.conn.rollback()
                raise UnknownMetadataException from excp
            except sqlite3.Error as excp:
                self._db.conn.rollback()
                raise excp

    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """Retrieve the ModelRepoMetadata corresponding to model key."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT metadata FROM model_metadata
                WHERE id=?;
                """,
                (model_key,),
            )
            rows = self._cursor.fetchone()
            if not rows:
                raise UnknownMetadataException("model metadata not found")
            return ModelMetadataFetchBase.from_json(rows[0])

    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
        """Dump out all the metadata."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id,metadata FROM model_metadata;
                """,
                (),
            )
            rows = self._cursor.fetchall()
        return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]

    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        """
        json_serialized = metadata.model_dump_json()  # turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE model_metadata
                    SET
                        metadata=?
                    WHERE id=?;
                    """,
                    (json_serialized, model_key),
                )
                if self._cursor.rowcount == 0:
                    raise UnknownMetadataException("model metadata not found")
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.Error as e:
                self._db.conn.rollback()
                raise e

        return self.get_metadata(model_key)

    def list_tags(self) -> Set[str]:
        """Return all tags in the tags table."""
        self._cursor.execute(
            """--sql
            select tag_text from tags;
            """
        )
        return {x[0] for x in self._cursor.fetchall()}

    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """Return the keys of models containing all of the listed tags."""
        with self._db.lock:
            try:
                matches: Optional[Set[str]] = None
                for tag in tags:
                    self._cursor.execute(
                        """--sql
                        SELECT a.model_id FROM model_tags AS a,
                               tags AS b
                        WHERE a.tag_id=b.tag_id
                          AND b.tag_text=?;
                        """,
                        (tag,),
                    )
                    model_keys = {x[0] for x in self._cursor.fetchall()}
                    if matches is None:
                        matches = model_keys
                    matches = matches.intersection(model_keys)
            except sqlite3.Error as e:
                raise e
        return matches if matches else set()

    def search_by_author(self, author: str) -> Set[str]:
        """Return the keys of models authored by the indicated author."""
        self._cursor.execute(
            """--sql
            SELECT id FROM model_metadata
            WHERE author=?;
            """,
            (author,),
        )
        return {x[0] for x in self._cursor.fetchall()}

    def search_by_name(self, name: str) -> Set[str]:
        """
        Return the keys of models with the indicated name.

        Note that this is the name of the model given to it by
        the remote source. The user may have changed the local
        name. The local name will be located in the model config
        record object.
        """
        self._cursor.execute(
            """--sql
            SELECT id FROM model_metadata
            WHERE name=?;
            """,
            (name,),
        )
        return {x[0] for x in self._cursor.fetchall()}

    def _update_tags(self, model_key: str, tags: Set[str]) -> None:
        """Update tags for the model referenced by model_key."""
        # remove previous tags from this model
        self._cursor.execute(
            """--sql
            DELETE FROM model_tags
            WHERE model_id=?;
            """,
            (model_key,),
        )

        for tag in tags:
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO tags (
                  tag_text
                )
                VALUES (?);
                """,
                (tag,),
            )
            self._cursor.execute(
                """--sql
                SELECT tag_id
                FROM tags
                WHERE tag_text = ?
                LIMIT 1;
                """,
                (tag,),
            )
            tag_id = self._cursor.fetchone()[0]
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO model_tags (
                   model_id,
                   tag_id
                )
                VALUES (?,?);
                """,
                (model_key, tag_id),
            )
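To make the store's contract concrete, here is a rough usage sketch (an illustration, not code from this diff); it assumes `db` is an already-migrated SqliteDatabase and `metadata` is an AnyModelRepoMetadata instance:

# Hedged sketch: exercising ModelMetadataStoreSQL as defined above.
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.backend.model_manager.metadata import UnknownMetadataException

store = ModelMetadataStoreSQL(db)  # `db` is an assumed, pre-migrated database
try:
    # FOREIGN KEY failure surfaces as UnknownMetadataException, per add_metadata()
    store.add_metadata("some-model-key", metadata)
except UnknownMetadataException:
    print("key not present in model_config")

print(store.list_tags())               # every tag recorded so far
print(store.search_by_tag({"anime"}))  # keys of models carrying all listed tags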
@ -6,19 +6,20 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field

from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import (
    AnyModelConfig,
    BaseModelType,
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata

from ..model_metadata import ModelMetadataStoreBase


class DuplicateModelException(Exception):
@ -59,33 +60,11 @@ class ModelSummary(BaseModel):
    tags: Set[str] = Field(description="tags associated with model")


class ModelRecordChanges(BaseModelExcludeNull):
    """A set of changes to apply to a model."""

    # Changes applicable to all models
    name: Optional[str] = Field(description="Name of the model.", default=None)
    path: Optional[str] = Field(description="Path to the model.", default=None)
    description: Optional[str] = Field(description="Model description", default=None)
    base: Optional[BaseModelType] = Field(description="The base model.", default=None)
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[ModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )

    # Checkpoint-specific changes
    # TODO(MM2): Should we expose these? Feels footgun-y...
    variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
    prediction_type: Optional[SchedulerPredictionType] = Field(
        description="The prediction type of the model.", default=None
    )
    upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)


class ModelRecordServiceBase(ABC):
    """Abstract base class for storage and retrieval of model configs."""

    @abstractmethod
    def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
    def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
        """
        Add a model to the database.

@ -109,12 +88,13 @@ class ModelRecordServiceBase(ABC):
        pass

    @abstractmethod
    def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
    def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
        """
        Update the model, returning the updated version.

        :param key: Unique key for the model to be updated.
        :param changes: A set of changes to apply to this model. Changes are validated before being written.
        :param key: Unique key for the model to be updated
        :param config: Model configuration record. Either a dict with the
         required fields, or a ModelConfigBase instance.
        """
        pass

@ -129,6 +109,40 @@ class ModelRecordServiceBase(ABC):
        """
        pass

    @property
    @abstractmethod
    def metadata_store(self) -> ModelMetadataStoreBase:
        """Return a ModelMetadataStore initialized on the same database."""
        pass

    @abstractmethod
    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
        """
        Retrieve metadata (if any) from when model was downloaded from a repo.

        :param key: Model key
        """
        pass

    @abstractmethod
    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
        """List metadata for all models that have it."""
        pass

    @abstractmethod
    def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
        """
        Search model metadata for ones with all listed tags and return their corresponding configs.

        :param tags: Set of tags to search for. All tags must be present.
        """
        pass

    @abstractmethod
    def list_tags(self) -> Set[str]:
        """Return a unique set of all the model tags in the metadata database."""
        pass

    @abstractmethod
    def list_models(
        self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
@ -203,3 +217,21 @@ class ModelRecordServiceBase(ABC):
            f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
        )
        return model_configs[0]

    def rename_model(
        self,
        key: str,
        new_name: str,
    ) -> AnyModelConfig:
        """
        Rename the indicated model. Just a special case of update_model().

        In some implementations, renaming the model may involve changing where
        it is stored on the filesystem. So this is broken out.

        :param key: Model key
        :param new_name: New name for model
        """
        config = self.get_model(key)
        config.name = new_name
        return self.update_model(key, config)
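As a hedged illustration of the ModelRecordChanges flavor of update_model() shown on one side of this diff (assuming `record_store` is an initialized implementation of ModelRecordServiceBase and the module path matches this file's location):

# Hypothetical sketch: partial updates through ModelRecordChanges.
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges

# Only the fields explicitly set here are applied; unset fields are untouched.
changes = ModelRecordChanges(name="my-model", description="renamed via the records API")
updated = record_store.update_model("some-model-key", changes)
print(updated.name, updated.description)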
@ -43,7 +43,7 @@ import json
import sqlite3
from math import ceil
from pathlib import Path
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
@ -53,11 +53,12 @@ from invokeai.backend.model_manager.config import (
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException

from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
    DuplicateModelException,
    ModelRecordChanges,
    ModelRecordOrderBy,
    ModelRecordServiceBase,
    ModelSummary,
@ -68,7 +69,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
    """Implementation of the ModelConfigStore ABC using a SQL database."""

    def __init__(self, db: SqliteDatabase):
    def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
        """
        Initialize a new object from preexisting sqlite3 connection and threading lock objects.

@ -77,13 +78,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        super().__init__()
        self._db = db
        self._cursor = db.conn.cursor()
        self._metadata_store = metadata_store

    @property
    def db(self) -> SqliteDatabase:
        """Return the underlying database."""
        return self._db

    def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
    def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
        """
        Add a model to the database.

@ -93,19 +95,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

        Can raise DuplicateModelException and InvalidModelConfigException exceptions.
        """
        record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config object.
        json_serialized = record.model_dump_json()  # and turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO models (
                    INSERT INTO model_config (
                        id,
                        original_hash,
                        config
                    )
                    VALUES (?,?);
                    VALUES (?,?,?);
                    """,
                    (
                        config.key,
                        config.model_dump_json(),
                        key,
                        record.original_hash,
                        json_serialized,
                    ),
                )
                self._db.conn.commit()
@ -113,12 +119,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            except sqlite3.IntegrityError as e:
                self._db.conn.rollback()
                if "UNIQUE constraint failed" in str(e):
                    if "models.path" in str(e):
                        msg = f"A model with path '{config.path}' is already installed"
                    elif "models.name" in str(e):
                        msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
                    if "model_config.path" in str(e):
                        msg = f"A model with path '{record.path}' is already installed"
                    elif "model_config.name" in str(e):
                        msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
                    else:
                        msg = f"A model with key '{config.key}' is already installed"
                        msg = f"A model with key '{key}' is already installed"
                    raise DuplicateModelException(msg) from e
                else:
                    raise e
@ -126,7 +132,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
                self._db.conn.rollback()
                raise e

        return self.get_model(config.key)
        return self.get_model(key)

    def del_model(self, key: str) -> None:
        """
@ -140,7 +146,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            try:
                self._cursor.execute(
                    """--sql
                    DELETE FROM models
                    DELETE FROM model_config
                    WHERE id=?;
                    """,
                    (key,),
@ -152,20 +158,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
                self._db.conn.rollback()
                raise e

    def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
        record = self.get_model(key)

        # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
        for field_name in changes.model_fields_set:
            setattr(record, field_name, getattr(changes, field_name))

        json_serialized = record.model_dump_json()
    def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
        """
        Update the model, returning the updated version.

        :param key: Unique key for the model to be updated
        :param config: Model configuration record. Either a dict with the
         required fields, or a ModelConfigBase instance.
        """
        record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config object
        json_serialized = record.model_dump_json()  # and turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE models
                    UPDATE model_config
                    SET
                        config=?
                    WHERE id=?;
@ -192,7 +199,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                SELECT config, strftime('%s',updated_at) FROM model_config
                WHERE id=?;
                """,
                (key,),
@ -213,7 +220,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                """--sql
                select count(*) FROM models
                select count(*) FROM model_config
                WHERE id=?;
                """,
                (key,),
@ -239,8 +246,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        If none of the optional filters are passed, will return all
        models in the database.
        """
        where_clause: list[str] = []
        bindings: list[str] = []
        results = []
        where_clause = []
        bindings = []
        if model_name:
            where_clause.append("name=?")
            bindings.append(model_name)
@ -257,13 +265,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                f"""--sql
                SELECT config, strftime('%s',updated_at) FROM models
                select config, strftime('%s',updated_at) FROM model_config
                {where};
                """,
                tuple(bindings),
            )
            result = self._cursor.fetchall()
            results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
            results = [
                ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
            ]
        return results

    def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -272,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                SELECT config, strftime('%s',updated_at) FROM model_config
                WHERE path=?;
                """,
                (str(path),),
@ -283,13 +292,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        return results

    def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
        """Return models with the indicated hash."""
        """Return models with the indicated original_hash."""
        results = []
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                WHERE hash=?;
                SELECT config, strftime('%s',updated_at) FROM model_config
                WHERE original_hash=?;
                """,
                (hash,),
            )
@ -298,35 +307,83 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            ]
        return results

    @property
    def metadata_store(self) -> ModelMetadataStoreBase:
        """Return a ModelMetadataStore initialized on the same database."""
        return self._metadata_store

    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
        """
        Retrieve metadata (if any) from when model was downloaded from a repo.

        :param key: Model key
        """
        store = self.metadata_store
        try:
            metadata = store.get_metadata(key)
            return metadata
        except UnknownMetadataException:
            return None

    def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
        """
        Search model metadata for ones with all listed tags and return their corresponding configs.

        :param tags: Set of tags to search for. All tags must be present.
        """
        store = ModelMetadataStoreSQL(self._db)
        keys = store.search_by_tag(tags)
        return [self.get_model(x) for x in keys]

    def list_tags(self) -> Set[str]:
        """Return a unique set of all the model tags in the metadata database."""
        store = ModelMetadataStoreSQL(self._db)
        return store.list_tags()

    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
        """List metadata for all models that have it."""
        store = ModelMetadataStoreSQL(self._db)
        return store.list_all_metadata()

    def list_models(
        self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
    ) -> PaginatedResults[ModelSummary]:
        """Return a paginated summary listing of each model in the database."""
        assert isinstance(order_by, ModelRecordOrderBy)
        ordering = {
            ModelRecordOrderBy.Default: "type, base, format, name",
            ModelRecordOrderBy.Type: "type",
            ModelRecordOrderBy.Base: "base",
            ModelRecordOrderBy.Name: "name",
            ModelRecordOrderBy.Format: "format",
            ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
            ModelRecordOrderBy.Type: "a.type",
            ModelRecordOrderBy.Base: "a.base",
            ModelRecordOrderBy.Name: "a.name",
            ModelRecordOrderBy.Format: "a.format",
        }

        def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
            """Fix up results so that there are no null values."""
            result: Dict[str, Union[str, int, Set[str]]] = {}
            for key, item in summary.items():
                result[key] = item or ""
            result["tags"] = set(json.loads(summary["tags"] or "[]"))
            return result

        # Lock so that the database isn't updated while we're doing the two queries.
        with self._db.lock:
            # query1: get the total number of model configs
            self._cursor.execute(
                """--sql
                select count(*) from models;
                select count(*) from model_config;
                """,
                (),
            )
            total = int(self._cursor.fetchone()[0])

            # query2: fetch key fields
            # query2: fetch key fields from the join of model_config and model_metadata
            self._cursor.execute(
                f"""--sql
                SELECT config
                FROM models
                SELECT a.id as key, a.type, a.base, a.format, a.name,
                       json_extract(a.config, '$.description') as description,
                       json_extract(b.metadata, '$.tags') as tags
                FROM model_config AS a
                LEFT JOIN model_metadata AS b on a.id=b.id
                ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
                LIMIT ?
                OFFSET ?;
@ -337,7 +394,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
                ),
            )
            rows = self._cursor.fetchall()
            items = [ModelSummary.model_validate(dict(x)) for x in rows]
            items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
            return PaginatedResults(
                page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
            )
@ -1,35 +1,6 @@
from abc import ABC, abstractmethod
from threading import Event

from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


class SessionRunnerBase(ABC):
    """
    Base class for session runner.
    """

    @abstractmethod
    def start(self, services: InvocationServices, cancel_event: Event) -> None:
        """Starts the session runner"""
        pass

    @abstractmethod
    def run(self, queue_item: SessionQueueItem) -> None:
        """Runs the session"""
        pass

    @abstractmethod
    def complete(self, queue_item: SessionQueueItem) -> None:
        """Completes the session"""
        pass

    @abstractmethod
    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
        """Runs an already prepared node on the session"""
        pass


class SessionProcessorBase(ABC):
@ -2,14 +2,13 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Callable, Optional, Union
from typing import Optional

from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -17,164 +16,15 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler

from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus


class DefaultSessionRunner(SessionRunnerBase):
    """Processes a single session's invocations"""

    def __init__(
        self,
        on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
        on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
    ):
        self.on_before_run_node = on_before_run_node
        self.on_after_run_node = on_after_run_node

    def start(self, services: InvocationServices, cancel_event: ThreadEvent):
        """Start the session runner"""
        self.services = services
        self.cancel_event = cancel_event

    def run(self, queue_item: SessionQueueItem):
        """Run the graph"""
        if not queue_item.session:
            raise ValueError("Queue item has no session")
        # Loop over invocations until the session is complete or canceled
        while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
            # Prepare the next node
            invocation = queue_item.session.next()
            if invocation is None:
                # If there are no more invocations, complete the graph
                break
            # Build invocation context (the node-facing API)
            self.run_node(invocation.id, queue_item)
        self.complete(queue_item)

    def complete(self, queue_item: SessionQueueItem):
        """Complete the graph"""
        self.services.events.emit_graph_execution_complete(
            queue_batch_id=queue_item.batch_id,
            queue_item_id=queue_item.item_id,
            queue_id=queue_item.queue_id,
            graph_execution_state_id=queue_item.session.id,
        )

    def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
        """Run before a node is executed"""
        # Send starting event
        self.services.events.emit_invocation_started(
            queue_batch_id=queue_item.batch_id,
            queue_item_id=queue_item.item_id,
            queue_id=queue_item.queue_id,
            graph_execution_state_id=queue_item.session_id,
            node=invocation.model_dump(),
            source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
        )
        if self.on_before_run_node is not None:
            self.on_before_run_node(invocation, queue_item)

    def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
        """Run after a node is executed"""
        if self.on_after_run_node is not None:
            self.on_after_run_node(invocation, queue_item)

    def run_node(self, node_id: str, queue_item: SessionQueueItem):
        """Run a single node in the graph"""
        # If this raises a NodeNotFoundError, it is handled by the processor
        invocation = queue_item.session.execution_graph.get_node(node_id)
        try:
            self._on_before_run_node(invocation, queue_item)
            data = InvocationContextData(
                invocation=invocation,
                source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
                queue_item=queue_item,
            )

            # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
            with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
                context = build_invocation_context(
                    data=data,
                    services=self.services,
                    cancel_event=self.cancel_event,
                )

                # Invoke the node
                outputs = invocation.invoke_internal(context=context, services=self.services)

                # Save outputs and history
                queue_item.session.complete(invocation.id, outputs)

            self._on_after_run_node(invocation, queue_item)
            # Send complete event on successful runs
            self.services.events.emit_invocation_complete(
                queue_batch_id=queue_item.batch_id,
                queue_item_id=queue_item.item_id,
                queue_id=queue_item.queue_id,
                graph_execution_state_id=queue_item.session.id,
                node=invocation.model_dump(),
                source_node_id=data.source_invocation_id,
                result=outputs.model_dump(),
            )
        except KeyboardInterrupt:
            # TODO(MM2): Create an event for this
            pass
        except CanceledException:
            # When the user cancels the graph, we first set the cancel event. The event is checked
            # between invocations, in this loop. Some invocations are long-running, and we need to
            # be able to cancel them mid-execution.
            #
            # For example, denoising is a long-running invocation with many steps. A step callback
            # is executed after each step. This step callback checks if the canceled event is set,
            # then raises a CanceledException to stop execution immediately.
            #
            # When we get a CanceledException, we don't need to do anything - just pass and let the
            # loop go to its next iteration, and the cancel event will be handled correctly.
            pass
        except Exception as e:
            error = traceback.format_exc()

            # Save error
            queue_item.session.set_node_error(invocation.id, error)
            self.services.logger.error(
                f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
            )
            self.services.logger.error(error)

            # Send error event
            self.services.events.emit_invocation_error(
                queue_batch_id=queue_item.session_id,
                queue_item_id=queue_item.item_id,
                queue_id=queue_item.queue_id,
                graph_execution_state_id=queue_item.session.id,
                node=invocation.model_dump(),
                source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
                error_type=e.__class__.__name__,
                error=error,
            )


class DefaultSessionProcessor(SessionProcessorBase):
    """Processes sessions from the session queue"""

    def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
        super().__init__()
        self.session_runner = session_runner if session_runner else DefaultSessionRunner()

    def start(
        self,
        invoker: Invoker,
        thread_limit: int = 1,
        polling_interval: int = 1,
        on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
        on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
    ) -> None:
    def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
        self._invoker: Invoker = invoker
        self._queue_item: Optional[SessionQueueItem] = None
        self._invocation: Optional[BaseInvocation] = None
        self.on_before_run_session = on_before_run_session
        self.on_after_run_session = on_after_run_session

        self._resume_event = ThreadEvent()
        self._stop_event = ThreadEvent()
@ -209,7 +59,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
                "cancel_event": self._cancel_event,
            },
        )
        self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
        self._thread.start()

    def stop(self, *args, **kwargs) -> None:
@ -268,34 +117,130 @@ class DefaultSessionProcessor(SessionProcessorBase):
            self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
            cancel_event.clear()

            # If we have an on_before_run_session callback, call it
            if self.on_before_run_session is not None:
                self.on_before_run_session(self._queue_item)

            # If profiling is enabled, start the profiler
            if self._profiler is not None:
                self._profiler.start(profile_id=self._queue_item.session_id)

            # Run the graph
            self.session_runner.run(queue_item=self._queue_item)
            # Prepare invocations and take the first
            self._invocation = self._queue_item.session.next()

            # If we are profiling, stop the profiler and dump the profile & stats
            if self._profiler:
                profile_path = self._profiler.stop()
                stats_path = profile_path.with_suffix(".json")
                self._invoker.services.performance_statistics.dump_stats(
                    graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
            # Loop over invocations until the session is complete or canceled
            while self._invocation is not None and not cancel_event.is_set():
                # get the source node id to provide to clients (the prepared node id is not as useful)
                source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]

                # Send starting event
                self._invoker.services.events.emit_invocation_started(
                    queue_batch_id=self._queue_item.batch_id,
                    queue_item_id=self._queue_item.item_id,
                    queue_id=self._queue_item.queue_id,
                    graph_execution_state_id=self._queue_item.session_id,
                    node=self._invocation.model_dump(),
                    source_node_id=source_invocation_id,
                )

                # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
                # we don't care about that - suppress the error.
                with suppress(GESStatsNotFoundError):
                    self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
                    self._invoker.services.performance_statistics.reset_stats()
                # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
                try:
                    with self._invoker.services.performance_statistics.collect_stats(
                        self._invocation, self._queue_item.session.id
                    ):
                        # Build invocation context (the node-facing API)
                        data = InvocationContextData(
                            invocation=self._invocation,
                            source_invocation_id=source_invocation_id,
                            queue_item=self._queue_item,
                        )
                        context = build_invocation_context(
                            data=data,
                            services=self._invoker.services,
                            cancel_event=self._cancel_event,
                        )

                        # If we have an on_after_run_session callback, call it
                        if self.on_after_run_session is not None:
                            self.on_after_run_session(self._queue_item)
                        # Invoke the node
                        outputs = self._invocation.invoke_internal(
                            context=context, services=self._invoker.services
                        )

                        # Save outputs and history
                        self._queue_item.session.complete(self._invocation.id, outputs)

                    # Send complete event
                    self._invoker.services.events.emit_invocation_complete(
                        queue_batch_id=self._queue_item.batch_id,
                        queue_item_id=self._queue_item.item_id,
                        queue_id=self._queue_item.queue_id,
                        graph_execution_state_id=self._queue_item.session.id,
                        node=self._invocation.model_dump(),
                        source_node_id=source_invocation_id,
                        result=outputs.model_dump(),
                    )

                except KeyboardInterrupt:
                    # TODO(MM2): Create an event for this
                    pass

                except CanceledException:
                    # When the user cancels the graph, we first set the cancel event. The event is checked
                    # between invocations, in this loop. Some invocations are long-running, and we need to
                    # be able to cancel them mid-execution.
                    #
                    # For example, denoising is a long-running invocation with many steps. A step callback
                    # is executed after each step. This step callback checks if the canceled event is set,
                    # then raises a CanceledException to stop execution immediately.
                    #
                    # When we get a CanceledException, we don't need to do anything - just pass and let the
                    # loop go to its next iteration, and the cancel event will be handled correctly.
                    pass

                except Exception as e:
                    error = traceback.format_exc()

                    # Save error
                    self._queue_item.session.set_node_error(self._invocation.id, error)
                    self._invoker.services.logger.error(
                        f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
                    )

                    # Send error event
                    self._invoker.services.events.emit_invocation_error(
                        queue_batch_id=self._queue_item.session_id,
                        queue_item_id=self._queue_item.item_id,
                        queue_id=self._queue_item.queue_id,
                        graph_execution_state_id=self._queue_item.session.id,
                        node=self._invocation.model_dump(),
                        source_node_id=source_invocation_id,
                        error_type=e.__class__.__name__,
                        error=error,
                    )
                    pass

                # The session is complete if all invocations are complete or there was an error
                if self._queue_item.session.is_complete() or cancel_event.is_set():
                    # Send complete event
                    self._invoker.services.events.emit_graph_execution_complete(
                        queue_batch_id=self._queue_item.batch_id,
                        queue_item_id=self._queue_item.item_id,
                        queue_id=self._queue_item.queue_id,
                        graph_execution_state_id=self._queue_item.session.id,
                    )
                    # If we are profiling, stop the profiler and dump the profile & stats
                    if self._profiler:
                        profile_path = self._profiler.stop()
                        stats_path = profile_path.with_suffix(".json")
                        self._invoker.services.performance_statistics.dump_stats(
                            graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
                        )
                    # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
                    # we don't care about that - suppress the error.
                    with suppress(GESStatsNotFoundError):
                        self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
                        self._invoker.services.performance_statistics.reset_stats()

                    # Set the invocation to None to prepare for the next session
                    self._invocation = None
                else:
                    # Prepare the next invocation
                    self._invocation = self._queue_item.session.next()

            # The session is complete, immediately poll for next session
            self._queue_item = None
@ -329,4 +274,3 @@ class DefaultSessionProcessor(SessionProcessorBase):
                poll_now_event.clear()
                self._queue_item = None
                self._thread_semaphore.release()
                self._invoker.services.logger.debug("Session processor stopped")
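Illustrative sketch only, not from this diff: wiring the optional node callbacks that the runner-based side of this file exposes. The callback type shown above is Callable[[BaseInvocation, SessionQueueItem], bool], and the classes are assumed importable from this module's path:

# Hypothetical sketch: custom per-node hooks on the default session runner.
from invokeai.app.services.session_processor.session_processor_default import (
    DefaultSessionProcessor,
    DefaultSessionRunner,
)

def log_node(invocation, queue_item) -> bool:
    # Called just before each node runs, per _on_before_run_node() above.
    print(f"running {invocation.id} for queue item {queue_item.item_id}")
    return True

runner = DefaultSessionRunner(on_before_run_node=log_node)
processor = DefaultSessionProcessor(session_runner=runner)  # None falls back to a default runner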
@ -9,7 +9,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@ -36,7 +35,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
    migrator.register_migration(build_migration_4())
    migrator.register_migration(build_migration_5())
    migrator.register_migration(build_migration_6())
    migrator.register_migration(build_migration_7())
    migrator.run_migrations()

    return db
@ -1,88 +0,0 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration7Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._create_models_table(cursor)
        self._drop_old_models_tables(cursor)

    def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
        """Drops the old model_records, model_metadata, model_tags and tags tables."""

        tables = ["model_records", "model_metadata", "model_tags", "tags"]

        for table in tables:
            cursor.execute(f"DROP TABLE IF EXISTS {table};")

    def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
        """Creates the v4.0.0 models table."""

        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS models (
                id TEXT NOT NULL PRIMARY KEY,
                hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
                base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
                type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
                path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
                format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
                name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
                description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
                source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
                source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
                source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
                trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
                -- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
                config TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- unique constraint on combo of name, base and type
                UNIQUE(name, base, type)
            );
            """
        ]

        # Add trigger for `updated_at`.
        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS models_updated_at
            AFTER UPDATE
            ON models FOR EACH ROW
            BEGIN
                UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE id = old.id;
            END;
            """
        ]

        # Add indexes for searchable fields
        indices = [
            "CREATE INDEX IF NOT EXISTS base_index ON models(base);",
            "CREATE INDEX IF NOT EXISTS type_index ON models(type);",
            "CREATE INDEX IF NOT EXISTS name_index ON models(name);",
            "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)


def build_migration_7() -> Migration:
    """
    Build the migration from database version 6 to 7.

    This migration does the following:
        - Adds the new models table
        - Drops the old model_records, model_metadata, model_tags and tags tables.
        - TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
    """
    migration_7 = Migration(
        from_version=6,
        to_version=7,
        callback=Migration7Callback(),
    )

    return migration_7
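The callback above can be smoke-tested against a throwaway in-memory database. A sketch under two assumptions: the module path is the pre-deletion one shown in this diff, and the SQLite build supports generated columns (3.31+):

# Hedged sketch: run Migration7Callback against an in-memory database.
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import Migration7Callback

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
Migration7Callback()(cursor)  # creates the models table, its indexes and trigger
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
print(cursor.fetchall())  # expect [('models',)]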
@ -3,6 +3,7 @@

import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
@ -21,7 +22,7 @@ from invokeai.backend.model_manager.config import (
    ModelConfigFactory,
    ModelType,
)
from invokeai.backend.model_manager.hash import ModelHash
from invokeai.backend.model_manager.hash import FastModelHash

ModelsValidator = TypeAdapter(AnyModelConfig)

@ -72,27 +73,19 @@ class MigrateModelYamlToDb1:

            base_type, model_type, model_name = str(model_key).split("/")
            try:
                hash = ModelHash().hash(self.config.models_path / stanza.path)
                hash = FastModelHash.hash(self.config.models_path / stanza.path)
            except OSError:
                self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
                continue

            assert isinstance(model_key, str)
            new_key = sha1(model_key.encode("utf-8")).hexdigest()

            stanza["base"] = BaseModelType(base_type)
            stanza["type"] = ModelType(model_type)
            stanza["name"] = model_name
            stanza["original_hash"] = hash
            stanza["current_hash"] = hash
            new_key = hash  # deterministic key assignment

            # special case for ip adapters, which need the new `image_encoder_model_id` field
            if stanza["type"] == ModelType.IPAdapter:
                try:
                    stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
                        self.config.models_path / stanza.path
                    )
                except OSError:
                    self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
                    continue

            new_config: AnyModelConfig = ModelsValidator.validate_python(stanza)  # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094

@ -102,7 +95,7 @@ class MigrateModelYamlToDb1:
                self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
                self._update_model(key, new_config)
            else:
                self.logger.info(f"Adding model {model_name} with key {new_key}")
                self.logger.info(f"Adding model {model_name} with key {model_key}")
                self._add_model(new_key, new_config)
        except DuplicateModelException:
            self.logger.warning(f"Model {model_name} is already in the database")
@ -150,14 +143,9 @@ class MigrateModelYamlToDb1:
                """,
                (
                    key,
                    record.hash,
                    record.original_hash,
                    json_serialized,
                ),
            )
        except sqlite3.IntegrityError as exc:
            raise DuplicateModelException(f"{record.name}: model is already in database") from exc

    def _get_image_encoder_model_id(self, model_path: Path) -> str:
        with open(model_path / "image_encoder.txt") as f:
            encoder = f.read()
        return encoder.strip()
|
@ -17,8 +17,7 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
55
invokeai/app/util/metadata.py
Normal file
55
invokeai/app/util/metadata.py
Normal file
@ -0,0 +1,55 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.shared.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
Only the general graph shape is validated; none of the fields are validated.
|
||||
|
||||
Any `metadata_accumulator` nodes and edges are removed.
|
||||
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
not isinstance(graph, dict)
|
||||
or "nodes" not in graph
|
||||
or not isinstance(graph["nodes"], dict)
|
||||
or "edges" not in graph
|
||||
or not isinstance(graph["edges"], list)
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
try:
|
||||
# delete the `metadata_accumulator` node
|
||||
del graph["nodes"]["metadata_accumulator"]
|
||||
except KeyError:
|
||||
# no accumulator node, all good
|
||||
pass
|
||||
|
||||
# delete any edges to or from it
|
||||
for i, edge in enumerate(graph["edges"]):
|
||||
try:
|
||||
# try to parse the edge
|
||||
Edge(**edge)
|
||||
except ValidationError:
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
if (
|
||||
edge["source"]["node_id"] == "metadata_accumulator"
|
||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||
):
|
||||
del graph["edges"][i]
|
||||
|
||||
return graph
|
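A minimal illustration of the helper above with a toy session payload (the node body is a placeholder, not a real session):

# Sketch: the accumulator node is stripped and the cleaned graph returned.
import json

from invokeai.app.util.metadata import get_metadata_graph_from_raw_session

session_raw = json.dumps(
    {
        "graph": {
            "nodes": {"metadata_accumulator": {"id": "metadata_accumulator"}},
            "edges": [],
        }
    }
)
graph = get_metadata_graph_from_raw_session(session_raw)
print(graph)  # -> {'nodes': {}, 'edges': []}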
@ -25,13 +25,10 @@ from enum import Enum
from typing import Literal, Optional, Type, Union

import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string

from ..raw_model import RawModel

# ModelMixin is the base class for all diffusers and transformers models
@ -59,8 +56,8 @@ class ModelType(str, Enum):

    ONNX = "onnx"
    Main = "main"
    VAE = "vae"
    LoRA = "lora"
    Vae = "vae"
    Lora = "lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
@ -76,9 +73,9 @@ class SubModelType(str, Enum):
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    VAE = "vae"
    VAEDecoder = "vae_decoder"
    VAEEncoder = "vae_encoder"
    Vae = "vae"
    VaeDecoder = "vae_decoder"
    VaeEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"

@ -96,8 +93,8 @@ class ModelFormat(str, Enum):

    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    LyCORIS = "lycoris"
    ONNX = "onnx"
    Lycoris = "lycoris"
    Onnx = "onnx"
    Olive = "olive"
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
@ -115,186 +112,127 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum):
    """Various hugging face variants on the diffusers format."""

    Default = ""  # model files without "fp16" or other qualifier - empty str
    DEFAULT = ""  # model files without "fp16" or other qualifier - empty str
    FP16 = "fp16"
    FP32 = "fp32"
    ONNX = "onnx"
    OpenVINO = "openvino"
    Flax = "flax"


class ModelSourceType(str, Enum):
    """Model source type."""

    Path = "path"
    Url = "url"
    HFRepoID = "hf_repo_id"
    CivitAI = "civitai"


class ModelDefaultSettings(BaseModel):
    vae: str | None
    vae_precision: str | None
    scheduler: SCHEDULER_NAME_VALUES | None
    steps: int | None
    cfg_scale: float | None
    cfg_rescale_multiplier: float | None
    OPENVINO = "openvino"
    FLAX = "flax"


class ModelConfigBase(BaseModel):
    """Base class for model configuration information."""

    key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
    hash: str = Field(description="The hash of the model file(s).")
    path: str = Field(
        description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
    )
    name: str = Field(description="Name of the model.")
    base: BaseModelType = Field(description="The base model.")
    description: Optional[str] = Field(description="Model description", default=None)
    source: str = Field(description="The original source of the model (path, URL or repo_id).")
    source_type: ModelSourceType = Field(description="The type of source")
    source_api_response: Optional[str] = Field(
        description="The original API response from the source, as stringified JSON.", default=None
    )
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[ModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )
    path: str = Field(description="filesystem path to the model file or directory")
    name: str = Field(description="model name")
    base: BaseModelType = Field(description="base model")
    type: ModelType = Field(description="type of the model")
    format: ModelFormat = Field(description="model format")
    key: str = Field(description="unique key for model", default="<NOKEY>")
    original_hash: Optional[str] = Field(
        description="original fasthash of model contents", default=None
    )  # this is assigned at install time and will not change
    current_hash: Optional[str] = Field(
        description="current fasthash of model contents", default=None
    )  # if model is converted or otherwise modified, this will hold updated hash
    description: Optional[str] = Field(description="human readable description of the model", default=None)
    source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
    last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)

    @staticmethod
    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
        schema["required"].extend(["key", "type", "format"])
        schema["required"].extend(
            ["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
        )

    model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
    model_config = ConfigDict(
        use_enum_values=False,
        validate_assignment=True,
        json_schema_extra=json_schema_extra,
    )

    def update(self, attributes: Dict[str, Any]) -> None:
        """Update the object with fields in dict."""
        for key, value in attributes.items():
            setattr(self, key, value)  # may raise a validation error


class CheckpointConfigBase(ModelConfigBase):
class _CheckpointConfig(ModelConfigBase):
    """Model config for checkpoint-style models."""

    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
    config_path: str = Field(description="path to the checkpoint model config file")
    converted_at: Optional[float] = Field(
        description="When this model was last converted to diffusers", default_factory=time.time
    )
    config: str = Field(description="path to the checkpoint model config file")


class DiffusersConfigBase(ModelConfigBase):
class _DiffusersConfig(ModelConfigBase):
    """Model config for diffusers-style models."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
    repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
    repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT


class LoRALyCORISConfig(ModelConfigBase):
class LoRAConfig(ModelConfigBase):
    """Model config for LoRA/Lycoris models."""

    type: Literal[ModelType.LoRA] = ModelType.LoRA
    format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
    type: Literal[ModelType.Lora] = ModelType.Lora
    format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]


class LoRADiffusersConfig(ModelConfigBase):
    """Model config for LoRA/Diffusers models."""

    type: Literal[ModelType.LoRA] = ModelType.LoRA
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")


class VAECheckpointConfig(CheckpointConfigBase):
class VaeCheckpointConfig(ModelConfigBase):
    """Model config for standalone VAE models."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    type: Literal[ModelType.Vae] = ModelType.Vae
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")


class VAEDiffusersConfig(ModelConfigBase):
class VaeDiffusersConfig(ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    type: Literal[ModelType.Vae] = ModelType.Vae
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")


class ControlNetDiffusersConfig(DiffusersConfigBase):
class ControlNetDiffusersConfig(_DiffusersConfig):
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")


class ControlNetCheckpointConfig(CheckpointConfigBase):
class ControlNetCheckpointConfig(_CheckpointConfig):
    """Model config for ControlNet models (checkpoint version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")


class TextualInversionFileConfig(ModelConfigBase):
class TextualInversionConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
    format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]


class TextualInversionFolderConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""
class _MainConfig(ModelConfigBase):
    """Model config for main models."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")


class MainCheckpointConfig(CheckpointConfigBase):
    """Model config for main checkpoint models."""

    type: Literal[ModelType.Main] = ModelType.Main
    vae: Optional[str] = Field(default=None)
    variant: ModelVariantType = ModelVariantType.Normal
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
    ztsnr_training: bool = False


class MainDiffusersConfig(DiffusersConfigBase):
    """Model config for main diffusers models."""
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
    """Model config for main checkpoint models."""

    type: Literal[ModelType.Main] = ModelType.Main

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")

class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
    """Model config for main diffusers models."""

    type: Literal[ModelType.Main] = ModelType.Main


class IPAdapterConfig(ModelConfigBase):
@ -304,10 +242,6 @@ class IPAdapterConfig(ModelConfigBase):
    image_encoder_model_id: str
    format: Literal[ModelFormat.InvokeAI]

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")


class CLIPVisionDiffusersConfig(ModelConfigBase):
    """Model config for ClipVision."""
@ -315,65 +249,58 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
    format: Literal[ModelFormat.Diffusers]

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")


class T2IAdapterConfig(ModelConfigBase):
class T2IConfig(ModelConfigBase):
    """Model config for T2I."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
    format: Literal[ModelFormat.Diffusers]

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")

_ControlNetConfig = Annotated[
    Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
    Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]

def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.
    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
    """
    format_ = None
    type_ = None
    if isinstance(v, dict):
        format_ = v.get("format")
        if isinstance(format_, Enum):
            format_ = format_.value
        type_ = v.get("type")
        if isinstance(type_, Enum):
            type_ = type_.value
    else:
        format_ = v.format.value
        type_ = v.type.value
    v = f"{type_}.{format_}"
    return v


AnyModelConfig = Annotated[
    Union[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
AnyModelConfig = Union[
    _MainModelConfig,
    _VaeConfig,
    _ControlNetConfig,
    # ModelConfigBase,
    LoRAConfig,
    TextualInversionConfig,
    IPAdapterConfig,
    CLIPVisionDiffusersConfig,
    T2IConfig,
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
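The callable-discriminator union on the old side of this hunk dispatches on a `"type.format"` tag instead of trying every union member. A minimal sketch, assuming the field values below (they are illustrative, not a real model record):

```py
# get_model_discriminator_value(lora_dict) returns "lora.lycoris", which matches
# LoRALyCORISConfig.get_tag(), so pydantic validates against that class only.
lora_dict = {
    "type": "lora",
    "format": "lycoris",
    "key": "0000-example-key",          # hypothetical key
    "hash": "blake3:0000000000000000",  # dummy hash
    "path": "loras/example.safetensors",
    "name": "example",
    "base": "sd-1",
    "source": "loras/example.safetensors",
    "source_type": "path",
}
config = AnyModelConfigValidator.validate_python(lora_dict)
assert isinstance(config, LoRALyCORISConfig)
```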

# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
#   https://github.com/tiangolo/fastapi/discussions/9761 and
#   https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
#     Union[
#         _MainModelConfig,
#         _ONNXConfig,
#         _VaeConfig,
#         _ControlNetConfig,
#         LoRAConfig,
#         TextualInversionConfig,
#         IPAdapterConfig,
#         CLIPVisionDiffusersConfig,
#         T2IConfig,
#     ],
#     Field(discriminator="type"),
# ]


class ModelConfigFactory(object):
    """Class for parsing config dicts into StableDiffusion Config objects."""

@ -405,6 +332,6 @@ class ModelConfigFactory(object):
        assert model is not None
        if key:
            model.key = key
        if isinstance(model, CheckpointConfigBase) and timestamp is not None:
            model.converted_at = timestamp
        if timestamp:
            model.last_modified = timestamp
        return model  # type: ignore
@ -11,175 +11,56 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Optional, Union
from typing import Dict, Union

from blake3 import blake3

MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")

ALGORITHM = Literal[
    "md5",
    "sha1",
    "sha224",
    "sha256",
    "sha384",
    "sha512",
    "blake2b",
    "blake2s",
    "sha3_224",
    "sha3_256",
    "sha3_384",
    "sha3_512",
    "shake_128",
    "shake_256",
    "blake3",
]
from imohash import hashfile


class ModelHash:
    """
    Creates a hash of a model using a specified algorithm.
class FastModelHash(object):
    """FastModelHash object provides one public class method, hash()."""

    Args:
        algorithm: Hashing algorithm to use. Defaults to BLAKE3.
        file_filter: A function that takes a file name and returns True if the file should be included in the hash.
    @classmethod
    def hash(cls, model_location: Union[str, Path]) -> str:
        """
        Return hexdigest string for model located at model_location.

        If the model is a single file, it is hashed directly using the provided algorithm.

        If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.

        Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth

        The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
        that directory hashes are never weaker than the file hashes.

        Usage:
        ```py
        # BLAKE3 hash
        ModelHash().hash("path/to/some/model.safetensors")
        # MD5
        ModelHash("md5").hash("path/to/model/dir/")
        ```
        """

    def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
        if algorithm == "blake3":
            self._hash_file = self._blake3
        elif algorithm in hashlib.algorithms_available:
            self._hash_file = self._get_hashlib(algorithm)
        :param model_location: Path to the model
        """
        model_location = Path(model_location)
        if model_location.is_file():
            return cls._hash_file(model_location)
        elif model_location.is_dir():
            return cls._hash_dir(model_location)
        else:
            raise ValueError(f"Algorithm {algorithm} not available")
            raise OSError(f"Not a valid file or directory: {model_location}")

        self._file_filter = file_filter or self._default_file_filter

    def hash(self, model_path: Union[str, Path]) -> str:
    @classmethod
    def _hash_file(cls, model_location: Union[str, Path]) -> str:
        """
        Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
        Fasthash a single file and return its hexdigest.

        If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
        directory. The final composite hash is always computed using BLAKE3.

        Args:
            model_path: Path to the model

        Returns:
            str: Hexdigest of the hash of the model
        :param model_location: Path to the model file
        """
        # we return md5 hash of the filehash to make it shorter
        # cryptographic security not needed here
        return hashlib.md5(hashfile(model_location)).hexdigest()

        model_path = Path(model_path)
        if model_path.is_file():
            return self._hash_file(model_path)
        elif model_path.is_dir():
            return self._hash_dir(model_path)
        else:
            raise OSError(f"Not a valid file or directory: {model_path}")
    @classmethod
    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
        components: Dict[str, str] = {}

    def _hash_dir(self, dir: Path) -> str:
        """Compute the hash for all files in a directory and return a hexdigest.
        for root, _dirs, files in os.walk(model_location):
            for file in files:
                # only tally tensor files because diffusers config files change slightly
                # depending on how the model was downloaded/converted.
                if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
                    continue
                path = (Path(root) / file).as_posix()
                fast_hash = cls._hash_file(path)
                components.update({path: fast_hash})

        Args:
            dir: Path to the directory

        Returns:
            str: Hexdigest of the hash of the directory
        """
        model_component_paths = self._get_file_paths(dir, self._file_filter)

        component_hashes: list[str] = []
        for component in sorted(model_component_paths):
            component_hashes.append(self._hash_file(component))

        # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
        # for the composite hash
        composite_hasher = blake3()
        for h in component_hashes:
            composite_hasher.update(h.encode("utf-8"))
        return composite_hasher.hexdigest()

    @staticmethod
    def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
        """Return a list of all model files in the directory.

        Args:
            model_path: Path to the model
            file_filter: Function that takes a file name and returns True if the file should be included in the list.

        Returns:
            List of all model files in the directory
        """

        files: list[Path] = []
        for root, _dirs, _files in os.walk(model_path):
            for file in _files:
                if file_filter(file):
                    files.append(Path(root, file))
        return files

    @staticmethod
    def _blake3(file_path: Path) -> str:
        """Hashes a file using BLAKE3

        Args:
            file_path: Path to the file to hash

        Returns:
            Hexdigest of the hash of the file
        """
        file_hasher = blake3(max_threads=blake3.AUTO)
        file_hasher.update_mmap(file_path)
        return file_hasher.hexdigest()

    @staticmethod
    def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
        """Factory function that returns a function to hash a file with the given algorithm.

        Args:
            algorithm: Hashing algorithm to use

        Returns:
            A function that hashes a file using the given algorithm
        """

        def hashlib_hasher(file_path: Path) -> str:
            """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
            hasher = hashlib.new(algorithm)
            buffer = bytearray(128 * 1024)
            mv = memoryview(buffer)
            with open(file_path, "rb", buffering=0) as f:
                while n := f.readinto(mv):
                    hasher.update(mv[:n])
            return hasher.hexdigest()

        return hashlib_hasher

    @staticmethod
    def _default_file_filter(file_path: str) -> bool:
        """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth

        Args:
            file_path: Path to the file

        Returns:
            True if the file matches the given extensions, otherwise False
        """
        return file_path.endswith(MODEL_FILE_EXTENSIONS)
        # hash all the model hashes together, using alphabetic file order
        md5 = hashlib.md5()
        for _path, fast_hash in sorted(components.items()):
            md5.update(fast_hash.encode("utf-8"))
        return md5.hexdigest()
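A hedged sketch contrasting the two APIs in this hunk; the two classes belong to opposite sides of the diff and never coexist, and the paths are hypothetical:

```py
# New API: algorithm chosen per instance, BLAKE3 by default; directories get a
# BLAKE3 composite over the sorted per-file hashes.
digest = ModelHash().hash("models/sd-1.5/")
md5_digest = ModelHash("md5").hash("models/model.ckpt")

# Old API: a single classmethod built on imohash's sampling hash plus md5.
legacy_digest = FastModelHash.hash("models/model.ckpt")
```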
@ -13,7 +13,6 @@ from invokeai.backend.model_manager import (
    ModelRepoVariant,
    SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -51,7 +50,7 @@ class ModelLoader(ModelLoaderBase):
        :param submodel_type: a ModelType enum indicating the portion of
            the model to retrieve (e.g. ModelType.Vae)
        """
        if model_config.type is ModelType.Main and not submodel_type:
        if model_config.type == "main" and not submodel_type:
            raise InvalidModelConfigException("submodel_type is required when loading a main model")

        model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -81,7 +80,7 @@ class ModelLoader(ModelLoaderBase):
            self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
            return self._convert_model(config, model_path, cache_path)

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
        return False

    def _load_if_needed(
@ -120,7 +119,7 @@ class ModelLoader(ModelLoaderBase):
        return calc_model_size_by_fs(
            model_path=model_path,
            subfolder=submodel_type.value if submodel_type else None,
            variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
            variant=config.repo_variant if hasattr(config, "repo_variant") else None,
        )

    # This needs to be implemented in subclasses that handle checkpoints
@ -15,8 +15,10 @@ Use like this:

"""

import hashlib
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type

from ..config import (
    AnyModelConfig,
@ -25,6 +27,8 @@ from ..config import (
    ModelFormat,
    ModelType,
    SubModelType,
    VaeCheckpointConfig,
    VaeDiffusersConfig,
)
from . import ModelLoaderBase

@ -57,9 +61,6 @@ class ModelLoaderRegistryBase(ABC):
    """


TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)


class ModelLoaderRegistry:
    """
    This class allows model loaders to register their type, base and format.
@ -70,10 +71,10 @@ class ModelLoaderRegistry:
    @classmethod
    def register(
        cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
    ) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
    ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
        """Define a decorator which registers the subclass of loader."""

        def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
        def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
            key = cls._to_registry_key(base, type, format)
            if key in cls._registry:
                raise Exception(
@ -89,15 +90,33 @@ class ModelLoaderRegistry:
        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
    ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
        """Get subclass of ModelLoaderBase registered to handle base and type."""
        # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
        conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)

        key1 = cls._to_registry_key(config.base, config.type, config.format)  # for a specific base type
        key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format)  # with wildcard Any
        key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format)  # for a specific base type
        key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format)  # with wildcard Any
        implementation = cls._registry.get(key1) or cls._registry.get(key2)
        if not implementation:
            raise NotImplementedError(
                f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
            )
        return implementation, config, submodel_type
        return implementation, conf2, submodel_type

    @classmethod
    def _handle_subtype_overrides(
        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
    ) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
        if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
            model_path = Path(config.vae)
            config_class = (
                VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
            )
            hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
            new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
            submodel_type = None
        else:
            new_conf = config
        return new_conf, submodel_type

    @staticmethod
    def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
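A sketch of the decorator above in use; `ExampleLoader` is a hypothetical class, and a real subclass would implement the `ModelLoaderBase` interface rather than `pass`:

```py
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
class ExampleLoader(ModelLoaderBase):
    """Stub loader used only to illustrate registration."""
    pass

# The decorator stores the class under _to_registry_key(base, type, format);
# lookups try the config's specific base first, then fall back to BaseModelType.Any.
```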
@ -3,8 +3,8 @@

from pathlib import Path

import safetensors
import torch
from safetensors.torch import load_file as safetensors_load_file

from invokeai.backend.model_manager import (
    AnyModelConfig,
@ -12,7 +12,6 @@ from invokeai.backend.model_manager import (
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers

from .. import ModelLoaderRegistry
@ -21,15 +20,15 @@ from .generic_diffusers import GenericDiffusersLoader

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlNetLoader(GenericDiffusersLoader):
class ControlnetLoader(GenericDiffusersLoader):
    """Class to load ControlNet models."""

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        if not isinstance(config, CheckpointConfigBase):
        if config.format != ModelFormat.Checkpoint:
            return False
        elif (
            dest_path.exists()
            and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
            and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
            and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
        ):
            return False
@ -38,13 +37,13 @@ class ControlNetLoader(GenericDiffusersLoader):

    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
            raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
            raise Exception(f"Vae conversion not supported for model type: {config.base}")
        else:
            assert isinstance(config, CheckpointConfigBase)
            config_file = config.config_path
            assert hasattr(config, "config")
            config_file = config.config

        if model_path.suffix == ".safetensors":
            checkpoint = safetensors_load_file(model_path, device="cpu")
            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
        else:
            checkpoint = torch.load(model_path, map_location="cpu")

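The checkpoint-loading branch above, restated as a standalone sketch for clarity (the function name is ours):

```py
from pathlib import Path

import torch
from safetensors.torch import load_file as safetensors_load_file


def load_checkpoint(model_path: Path) -> dict:
    # .safetensors files go through the safetensors reader; everything else
    # falls back to torch.load on CPU.
    if model_path.suffix == ".safetensors":
        return safetensors_load_file(model_path, device="cpu")
    return torch.load(model_path, map_location="cpu")
```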
@ -3,10 +3,9 @@

import sys
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Optional

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers import ConfigMixin, ModelMixin

from invokeai.backend.model_manager import (
    AnyModel,
@ -42,7 +41,6 @@ class GenericDiffusersLoader(ModelLoader):
    # TO DO: Add exception handling
    def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
        """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
        result = None
        if submodel_type:
            try:
                config = self._load_diffusers_config(model_path, config_name="model_index.json")
@ -66,7 +64,6 @@ class GenericDiffusersLoader(ModelLoader):
                raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
            except KeyError as e:
                raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
        assert result is not None
        return result

    # TO DO: Add exception handling
@ -78,7 +75,7 @@ class GenericDiffusersLoader(ModelLoader):
        result: ModelMixin = getattr(res_type, class_name)
        return result

    def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
    def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
        return ConfigLoader.load_config(model_path, config_name=config_name)


@ -86,8 +83,8 @@ class ConfigLoader(ConfigMixin):
    """Subclass of ConfigMixin for loading diffusers configuration files."""

    @classmethod
    def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]:  # pyright: ignore [reportIncompatibleMethodOverride]
    def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Load a diffusers ConfigMixin configuration."""
        cls.config_name = kwargs.pop("config_name")
        # TODO(psyche): the types on this diffusers method are not correct
        # Diffusers doesn't provide typing info
        return super().load_config(*args, **kwargs)  # type: ignore
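The `getattr`-based lookup above resolves a class name from a diffusers config against the `diffusers` module. A minimal sketch, assuming the class name comes from a config's `_class_name` entry:

```py
import sys

import diffusers  # ensures the module is present in sys.modules below

class_name = "AutoencoderKL"  # illustrative; normally read from config.json
res_type = sys.modules["diffusers"]
load_class = getattr(res_type, class_name)
assert load_class is diffusers.AutoencoderKL
```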
@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
        if submodel_type is not None:
            raise ValueError("There are no submodels in an IP-Adapter model.")
        model = build_ip_adapter(
            ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
            ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
            device=torch.device("cpu"),
            dtype=self._torch_dtype,
        )
@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from .. import ModelLoader, ModelLoaderRegistry


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader):
    """Class to load LoRA models."""

@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(GenericDiffusersLoader):
    """Class to load onnx models."""
@ -4,8 +4,7 @@
from pathlib import Path
from typing import Optional

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline

from invokeai.backend.model_manager import (
    AnyModel,
@ -17,7 +16,7 @@ from invokeai.backend.model_manager import (
    ModelVariantType,
    SubModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

from .. import ModelLoaderRegistry
@ -55,11 +54,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
        return result

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        if not isinstance(config, CheckpointConfigBase):
        if config.format != ModelFormat.Checkpoint:
            return False
        elif (
            dest_path.exists()
            and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
            and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
            and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
        ):
            return False
@ -74,7 +73,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
            StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
        )

        config_file = config.config_path
        config_file = config.config

        self._logger.info(f"Converting {model_path} to diffusers format")
        convert_ckpt_to_diffusers(
@ -3,9 +3,9 @@

from pathlib import Path

import safetensors
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file

from invokeai.backend.model_manager import (
    AnyModelConfig,
@ -13,25 +13,24 @@ from invokeai.backend.model_manager import (
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers

from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader):
    """Class to load VAE models."""

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        if not isinstance(config, CheckpointConfigBase):
        if config.format != ModelFormat.Checkpoint:
            return False
        elif (
            dest_path.exists()
            and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
            and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
            and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
        ):
            return False
@ -39,15 +38,16 @@ class VaeLoader(GenericDiffusersLoader):
        return True

    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
        # TODO(MM2): check whether sdxl VAE models convert.
        # TO DO: check whether sdxl VAE models convert.
        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
            raise Exception(f"VAE conversion not supported for model type: {config.base}")
            raise Exception(f"Vae conversion not supported for model type: {config.base}")
        else:
            assert isinstance(config, CheckpointConfigBase)
            config_file = config.config_path
            config_file = (
                "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
            )

        if model_path.suffix == ".safetensors":
            checkpoint = safetensors_load_file(model_path, device="cpu")
            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
        else:
            checkpoint = torch.load(model_path, map_location="cpu")

@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
        ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
        assert isinstance(ckpt_config, DictConfig)

        vae_model = convert_ldm_vae_to_diffusers(
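A hedged sketch of the `OmegaConf` step above; the YAML path is illustrative, and the real code resolves it against the app config's `legacy_conf_path` (new side) or `root_path` (old side):

```py
from omegaconf import DictConfig, OmegaConf

ckpt_config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
assert isinstance(ckpt_config, DictConfig)
# The DictConfig is then passed to convert_ldm_vae_to_diffusers together with
# the checkpoint's state_dict.
```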
@ -16,7 +16,6 @@ from diffusers import AutoPipelineForText2Image
from diffusers.utils import logging as dlogging

from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

from . import (
@ -118,6 +117,7 @@ class ModelMerger(object):
        config = self._installer.app_config
        store = self._installer.record_store
        base_models: Set[BaseModelType] = set()
        vae = None
        variant = None if self._installer.app_config.full_precision else "fp16"

        assert (
@ -134,6 +134,10 @@ class ModelMerger(object):
                "normal"
            ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"

            # pick up the first model's vae
            if key == model_keys[0]:
                vae = info.vae

            # tally base models used
            base_models.add(info.base)
            model_paths.extend([config.models_path / info.path])
@ -159,10 +163,12 @@ class ModelMerger(object):

        # update model's config
        model_config = self._installer.record_store.get_model(key)
        model_config.name = merged_model_name
        model_config.description = f"Merge of models {', '.join(model_names)}"

        self._installer.record_store.update_model(
            key, ModelRecordChanges(name=model_config.name, description=model_config.description)
        model_config.update(
            {
                "name": merged_model_name,
                "description": f"Merge of models {', '.join(model_names)}",
                "vae": vae,
            }
        )
        self._installer.record_store.update_model(key, model_config)
        return model_config
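The two record-update styles in this hunk, side by side as a sketch; `record_store`, `key` and `model_config` are assumed to exist as in the surrounding method:

```py
# New style: a typed change-set object describes only the fields that change.
changes = ModelRecordChanges(name="merged-model", description="Merge of models A, B")
record_store.update_model(key, changes)

# Old style: mutate the config via its dict-based update(), then write it back.
model_config.update({"name": "merged-model", "description": "Merge of models A, B"})
record_store.update_model(key, model_config)
```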
@ -25,7 +25,9 @@ from .metadata_base import (
    AnyModelRepoMetadataValidator,
    BaseMetadata,
    CivitaiMetadata,
    CommercialUsage,
    HuggingFaceMetadata,
    LicenseRestrictions,
    ModelMetadataWithFiles,
    RemoteModelFile,
    UnknownMetadataException,
@ -36,8 +38,10 @@ __all__ = [
    "AnyModelRepoMetadataValidator",
    "CivitaiMetadata",
    "CivitaiMetadataFetch",
    "CommercialUsage",
    "HuggingFaceMetadata",
    "HuggingFaceMetadataFetch",
    "LicenseRestrictions",
    "ModelMetadataFetchBase",
    "BaseMetadata",
    "ModelMetadataWithFiles",
@ -23,21 +23,22 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""

import json
import re
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Optional

import requests
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager import ModelRepoVariant

from ..metadata_base import (
    AnyModelRepoMetadata,
    CivitaiMetadata,
    CommercialUsage,
    LicenseRestrictions,
    RemoteModelFile,
    UnknownMetadataException,
)
@ -51,13 +52,10 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"


StringSetAdapter = TypeAdapter(set[str])


class CivitaiMetadataFetch(ModelMetadataFetchBase):
    """Fetch model metadata from Civitai."""

    def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

@ -65,7 +63,6 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
        this module without an internet connection.
        """
        self._requests = session or requests.Session()
        self._api_key = api_key

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
@ -105,21 +102,22 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
        May raise an `UnknownMetadataException`.
        """
        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
        return self._from_api_response(model_json)
        model_json = self._requests.get(model_url).json()
        return self._from_model_json(model_json)

    def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
    def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
        try:
            version_id = version_id or api_response["modelVersions"][0]["id"]
            version_id = version_id or model_json["modelVersions"][0]["id"]
        except TypeError as excp:
            raise UnknownMetadataException from excp

        # loop till we find the section containing the version requested
        version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
        version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
        if not version_sections:
            raise UnknownMetadataException(f"Version {version_id} not found in model metadata")

        version_json = version_sections[0]
        safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]

        # Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
        primary = [x for x in version_json["files"] if x.get("primary")]
@ -136,23 +134,36 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
            url = url + f"?type={primary_file['type']}{metadata_string}"
        model_files = [
            RemoteModelFile(
                url=self._get_url_with_api_key(url),
                url=url,
                path=Path(primary_file["name"]),
                size=int(primary_file["sizeKB"] * 1024),
                sha256=primary_file["hashes"]["SHA256"],
            )
        ]

        try:
            trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
        except ValidationError:
            trigger_phrases: set[str] = set()

        return CivitaiMetadata(
            id=model_json["id"],
            name=version_json["name"],
            version_id=version_json["id"],
            version_name=version_json["name"],
            created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
            updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
            published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
            base_model_trained_on=version_json["baseModel"],  # note - need a dictionary to turn into a BaseModelType
            files=model_files,
            trigger_phrases=trigger_phrases,
            api_response=json.dumps(version_json),
            download_url=version_json["downloadUrl"],
            thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
            author=model_json["creator"]["username"],
            description=model_json["description"],
            version_description=version_json["description"] or "",
            tags=model_json["tags"],
            trained_words=version_json["trainedWords"],
            nsfw=model_json["nsfw"],
            restrictions=LicenseRestrictions(
                AllowNoCredit=model_json["allowNoCredit"],
                AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
                AllowDerivatives=model_json["allowDerivatives"],
                AllowDifferentLicense=model_json["allowDifferentLicense"],
            ),
        )

    def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
@ -163,14 +174,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
        """
        if model_id is None:
            version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
            version = self._requests.get(self._get_url_with_api_key(version_url)).json()
            version = self._requests.get(version_url).json()
            if error := version.get("error"):
                raise UnknownMetadataException(error)
            model_id = version["modelId"]

        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
        return self._from_api_response(model_json, version_id)
        model_json = self._requests.get(model_url).json()
        return self._from_model_json(model_json, version_id)

    @classmethod
    def from_json(cls, json: str) -> CivitaiMetadata:
@ -178,11 +189,6 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
        metadata = CivitaiMetadata.model_validate_json(json)
        return metadata

    def _get_url_with_api_key(self, url: str) -> str:
        if not self._api_key:
            return url

        if "?" in url:
            return f"{url}&token={self._api_key}"

        return f"{url}?token={self._api_key}"
def _fix_timezone(date: str) -> str:
    return re.sub(r"Z$", "+00:00", date)
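The `_get_url_with_api_key` logic removed on the new side of this hunk, restated as a standalone sketch (the function name and token value are ours):

```py
def with_api_key(url: str, api_key: str | None) -> str:
    if not api_key:
        return url
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}token={api_key}"


assert with_api_key("https://civitai.com/api/v1/models/123", "KEY").endswith("?token=KEY")
assert with_api_key("https://civitai.com/download?type=Model", "KEY").endswith("&token=KEY")
```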
@ -13,7 +13,6 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""

import json
import re
from pathlib import Path
from typing import Optional
@ -24,7 +23,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager import ModelRepoVariant

from ..metadata_base import (
    AnyModelRepoMetadata,
@ -61,7 +60,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
        # Little loop which tries fetching a revision corresponding to the selected variant.
        # If not available, then set variant to None and get the default.
        # If this too fails, raise exception.

        model_info = None
        while not model_info:
            try:
@ -74,24 +72,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
                else:
                    variant = None

        files: list[RemoteModelFile] = []

        _, name = id.split("/")

        for s in model_info.siblings or []:
            assert s.rfilename is not None
            assert s.size is not None
            files.append(
                RemoteModelFile(
                    url=hf_hub_url(id, s.rfilename, revision=variant),
                    path=Path(name, s.rfilename),
                    size=s.size,
                    sha256=s.lfs.get("sha256") if s.lfs else None,
                )
            )

        return HuggingFaceMetadata(
            id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
            id=model_info.id,
            author=model_info.author,
            name=name,
            last_modified=model_info.last_modified,
            tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
            tags=model_info.tags,
            files=[
                RemoteModelFile(
                    url=hf_hub_url(id, x.rfilename, revision=variant),
                    path=Path(name, x.rfilename),
                    size=x.size,
                    sha256=x.lfs.get("sha256") if x.lfs else None,
                )
                for x in model_info.siblings
            ],
        )

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
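The file URLs above come from `huggingface_hub.hf_hub_url`. A small sketch, with an illustrative repo id and filename:

```py
from huggingface_hub import hf_hub_url

url = hf_hub_url("stabilityai/sdxl-turbo", "model_index.json", revision=None)
# With revision=None this resolves against the default branch, e.g.
# https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json
```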
@ -14,8 +14,10 @@ versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@ -32,6 +34,31 @@ class UnknownMetadataException(Exception):
|
||||
"""Raised when no metadata is available for a model."""
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
class RemoteModelFile(BaseModel):
|
||||
"""Information about a downloadable file that forms part of a model."""
|
||||
|
||||
@ -45,6 +72,8 @@ class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
@ -82,16 +111,60 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
|
||||
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
||||
id: int = Field(description="Civitai version identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_minmax: Tuple[float, float] = Field(
|
||||
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
|
||||
) # note: For future use
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
# accommodate schema change
|
||||
acu = self.restrictions.AllowCommercialUse
|
||||
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||
return CommercialUsage.No not in commercial_usage
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
|
||||
|
||||
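The `allow_commercial_use` property has to tolerate both the old scalar form and the newer set form of `AllowCommercialUse`. A minimal standalone sketch of that normalization, assuming the `CommercialUsage` enum above (the values used here are illustrative):

from typing import Optional, Set, Union

def commercial_use_allowed(acu: Optional[Union[Set[CommercialUsage], CommercialUsage]]) -> bool:
    # None means the license grants no commercial use at all
    if acu is None:
        return False
    # normalize the legacy scalar form to a one-element set
    usage = acu if isinstance(acu, set) else {acu}
    return CommercialUsage.No not in usage

# commercial_use_allowed({CommercialUsage.Image, CommercialUsage.Sell}) -> True
# commercial_use_allowed(CommercialUsage.No) -> False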
class HuggingFaceMetadata(ModelMetadataWithFiles):
    """Extended metadata fields provided by HuggingFace."""

    type: Literal["huggingface"] = "huggingface"
    id: str = Field(description="The HF model id")
    api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
    id: str = Field(description="huggingface model id")
    tag_dict: Dict[str, Any]
    last_modified: datetime = Field(description="date of last commit to repo")

    def download_urls(
        self,
@ -120,7 +193,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
        # the next step reads model_index.json to determine which subdirectories belong
        # to the model
        if Path(f"{prefix}model_index.json") in paths:
            url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
            url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
            resp = session.get(url)
            resp.raise_for_status()
            submodels = resp.json()
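For context, the `model_index.json` fetch above resolves which subfolders of a diffusers repo belong to the model. A standalone sketch of the same idea, assuming the `huggingface_hub` and `requests` packages and an illustrative public repo id:

import requests
from huggingface_hub import hf_hub_url

repo_id = "runwayml/stable-diffusion-v1-5"  # illustrative repo id
url = hf_hub_url(repo_id, filename="model_index.json")
resp = requests.get(url, timeout=30)
resp.raise_for_status()
index = resp.json()
# keys other than "_class_name" / "_diffusers_version" name the submodel folders
submodel_folders = [k for k in index if not k.startswith("_")]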
221
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
@ -0,0 +1,221 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""

import sqlite3
from typing import List, Optional, Set, Tuple

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase

from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException


class ModelMetadataStore:
    """Store, search and fetch model metadata retrieved from remote repositories."""

    def __init__(self, db: SqliteDatabase):
        """
        Initialize a new object from a preexisting SqliteDatabase, which supplies the sqlite3 connection and threading lock.

        :param db: SqliteDatabase object
        """
        super().__init__()
        self._db = db
        self._cursor = self._db.conn.cursor()

    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        """
        json_serialized = metadata.model_dump_json()
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO model_metadata(
                       id,
                       metadata
                    )
                    VALUES (?,?);
                    """,
                    (
                        model_key,
                        json_serialized,
                    ),
                )
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
                self._db.conn.rollback()
                raise UnknownMetadataException from excp
            except sqlite3.Error as excp:
                self._db.conn.rollback()
                raise excp

    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """Retrieve the ModelRepoMetadata corresponding to model key."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT metadata FROM model_metadata
                WHERE id=?;
                """,
                (model_key,),
            )
            rows = self._cursor.fetchone()
            if not rows:
                raise UnknownMetadataException("model metadata not found")
            return ModelMetadataFetchBase.from_json(rows[0])

    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
        """Dump out all the metadata."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id,metadata FROM model_metadata;
                """,
                (),
            )
            rows = self._cursor.fetchall()
        return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]

    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        """
        json_serialized = metadata.model_dump_json()  # turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE model_metadata
                    SET
                        metadata=?
                    WHERE id=?;
                    """,
                    (json_serialized, model_key),
                )
                if self._cursor.rowcount == 0:
                    raise UnknownMetadataException("model metadata not found")
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.Error as e:
                self._db.conn.rollback()
                raise e

        return self.get_metadata(model_key)

    def list_tags(self) -> Set[str]:
        """Return all tags in the tags table."""
        self._cursor.execute(
            """--sql
            select tag_text from tags;
            """
        )
        return {x[0] for x in self._cursor.fetchall()}

    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """Return the keys of models containing all of the listed tags."""
        with self._db.lock:
            try:
                matches: Optional[Set[str]] = None
                for tag in tags:
                    self._cursor.execute(
                        """--sql
                        SELECT a.model_id FROM model_tags AS a,
                                               tags AS b
                        WHERE a.tag_id=b.tag_id
                          AND b.tag_text=?;
                        """,
                        (tag,),
                    )
                    model_keys = {x[0] for x in self._cursor.fetchall()}
                    if matches is None:
                        matches = model_keys
                    matches = matches.intersection(model_keys)
            except sqlite3.Error as e:
                raise e
        return matches if matches else set()

    def search_by_author(self, author: str) -> Set[str]:
        """Return the keys of models authored by the indicated author."""
        self._cursor.execute(
            """--sql
            SELECT id FROM model_metadata
            WHERE author=?;
            """,
            (author,),
        )
        return {x[0] for x in self._cursor.fetchall()}

    def search_by_name(self, name: str) -> Set[str]:
        """
        Return the keys of models with the indicated name.

        Note that this is the name of the model given to it by
        the remote source. The user may have changed the local
        name. The local name will be located in the model config
        record object.
        """
        self._cursor.execute(
            """--sql
            SELECT id FROM model_metadata
            WHERE name=?;
            """,
            (name,),
        )
        return {x[0] for x in self._cursor.fetchall()}

    def _update_tags(self, model_key: str, tags: Set[str]) -> None:
        """Update tags for the model referenced by model_key."""
        # remove previous tags from this model
        self._cursor.execute(
            """--sql
            DELETE FROM model_tags
            WHERE model_id=?;
            """,
            (model_key,),
        )

        for tag in tags:
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO tags (
                  tag_text
                )
                VALUES (?);
                """,
                (tag,),
            )
            self._cursor.execute(
                """--sql
                SELECT tag_id
                FROM tags
                WHERE tag_text = ?
                LIMIT 1;
                """,
                (tag,),
            )
            tag_id = self._cursor.fetchone()[0]
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO model_tags (
                   model_id,
                   tag_id
                )
                VALUES (?,?);
                """,
                (model_key, tag_id),
            )
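A minimal usage sketch for the store above, assuming a configured `SqliteDatabase`, a model record already present in `model_config` under the illustrative key "abc123", and a `metadata` object obtained from one of the fetchers:

store = ModelMetadataStore(db=db)  # db normally comes from the app's service container
store.add_metadata("abc123", metadata)  # raises UnknownMetadataException if the key is absent
meta = store.get_metadata("abc123")
keys = store.search_by_tag({"stable-diffusion", "lora"})  # keys matching ALL listed tags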
@ -8,7 +8,6 @@ import torch
from picklescan.scanner import scan_file_path

import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.util import SilenceWarnings

from .config import (
@ -18,12 +17,11 @@ from .config import (
    ModelConfigFactory,
    ModelFormat,
    ModelRepoVariant,
    ModelSourceType,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
)
from .hash import ModelHash
from .hash import FastModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta

CkptType = Dict[str, Any]
@ -97,8 +95,8 @@ class ModelProbe(object):
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "LatentConsistencyModelPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.VAE,
        "AutoencoderTiny": ModelType.VAE,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
@ -110,6 +108,14 @@ class ModelProbe(object):
    ) -> None:
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model_path: Path,
        fields: Optional[Dict[str, Any]] = None,
    ) -> AnyModelConfig:
        return cls.probe(model_path, fields)

    @classmethod
    def probe(
        cls,
@ -131,21 +137,19 @@ class ModelProbe(object):
        format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
        model_info = None
        model_type = None
        if format_type is ModelFormat.Diffusers:
        if format_type == "diffusers":
            model_type = cls.get_model_type_from_folder(model_path)
        else:
            model_type = cls.get_model_type_from_checkpoint(model_path)
        format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
        format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type

        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")

        hash = FastModelHash.hash(model_path)
        probe = probe_class(model_path)

        fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
        fields["source"] = fields.get("source") or model_path.as_posix()
        fields["key"] = fields.get("key", uuid_string())
        fields["path"] = model_path.as_posix()
        fields["type"] = fields.get("type") or model_type
        fields["base"] = fields.get("base") or probe.get_base_type()
@ -157,17 +161,15 @@ class ModelProbe(object):
            fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
        )
        fields["format"] = fields.get("format") or probe.get_format()
        fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
        fields["original_hash"] = fields.get("original_hash") or hash
        fields["current_hash"] = fields.get("current_hash") or hash

        if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
        if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
            fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

        # additional fields needed for main and controlnet models
        if (
            fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
            and fields["format"] is ModelFormat.Checkpoint
        ):
            fields["config_path"] = cls._get_checkpoint_config_path(
        if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
            fields["config"] = cls._get_checkpoint_config_path(
                model_path,
                model_type=fields["type"],
                base_type=fields["base"],
@ -177,7 +179,7 @@ class ModelProbe(object):

        # additional fields needed for main non-checkpoint models
        elif fields["type"] == ModelType.Main and fields["format"] in [
            ModelFormat.ONNX,
            ModelFormat.Onnx,
            ModelFormat.Olive,
            ModelFormat.Diffusers,
        ]:
@ -186,7 +188,7 @@ class ModelProbe(object):
                and fields["prediction_type"] == SchedulerPredictionType.VPrediction
            )

        model_info = ModelConfigFactory.make_config(fields)  # , key=fields.get("key", None))
        model_info = ModelConfigFactory.make_config(fields, key=fields.get("key", None))
        return model_info

    @classmethod
@ -211,11 +213,11 @@ class ModelProbe(object):
        if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
            return ModelType.Main
        elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
            return ModelType.VAE
            return ModelType.Vae
        elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
            return ModelType.LoRA
            return ModelType.Lora
        elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
            return ModelType.LoRA
            return ModelType.Lora
        elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
            return ModelType.ControlNet
        elif key in {"emb_params", "string_to_param"}:
@ -237,7 +239,7 @@ class ModelProbe(object):
        if (folder_path / f"learned_embeds.{suffix}").exists():
            return ModelType.TextualInversion
        if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
            return ModelType.LoRA
            return ModelType.Lora
        if (folder_path / "unet/model.onnx").exists():
            return ModelType.ONNX
        if (folder_path / "image_encoder.txt").exists():
@ -283,21 +285,13 @@ class ModelProbe(object):
        if possible_conf.exists():
            return possible_conf.absolute()

        if model_type is ModelType.Main:
        if model_type == ModelType.Main:
            config_file = LEGACY_CONFIGS[base_type][variant_type]
            if isinstance(config_file, dict):  # need another tier for sd-2.x models
                config_file = config_file[prediction_type]
        elif model_type is ModelType.ControlNet:
        elif model_type == ModelType.ControlNet:
            config_file = (
                "../controlnet/cldm_v15.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "../controlnet/cldm_v21.yaml"
            )
        elif model_type is ModelType.VAE:
            config_file = (
                "../stable-diffusion/v1-inference.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "../stable-diffusion/v2-inference.yaml"
                "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
            )
        else:
            raise InvalidModelConfigException(
@ -503,12 +497,12 @@ class FolderProbeBase(ProbeBase):
        if ".fp16" in x.suffixes:
            return ModelRepoVariant.FP16
        if "openvino_model" in x.name:
            return ModelRepoVariant.OpenVINO
            return ModelRepoVariant.OPENVINO
        if "flax_model" in x.name:
            return ModelRepoVariant.Flax
            return ModelRepoVariant.FLAX
        if x.suffix == ".onnx":
            return ModelRepoVariant.ONNX
        return ModelRepoVariant.Default
        return ModelRepoVariant.DEFAULT


class PipelineFolderProbe(FolderProbeBase):
@ -714,8 +708,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):

############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
@ -723,8 +717,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
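The registry above maps (format, model type) pairs to probe classes; `probe()` picks the format from dir-vs-file, infers the model type from the weights, and dispatches. A hedged sketch of a call site, with an illustrative path (the full `probe()` signature is not shown in this hunk, so treat the argument list as an assumption):

from pathlib import Path

config = ModelProbe.probe(Path("/models/some-model.safetensors"))  # hypothetical path
print(config.type, config.base, config.format)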
@ -13,7 +13,6 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
"""

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set

@ -35,7 +34,7 @@ def filter_files(
    The file list can be obtained from the `files` field of HuggingFaceMetadata,
    as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
    """
    variant = variant or ModelRepoVariant.Default
    variant = variant or ModelRepoVariant.DEFAULT
    paths: List[Path] = []
    root = files[0].parts[0]

@ -74,81 +73,64 @@ def filter_files(
    return sorted(_filter_by_variant(paths, variant))


@dataclass
class SubfolderCandidate:
    path: Path
    score: int


def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
    """Select the proper variant files from a list of HuggingFace repo_id paths."""
    result: set[Path] = set()
    subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
    result = set()
    basenames: Dict[Path, Path] = {}
    for path in files:
        if path.suffix in [".onnx", ".pb", ".onnx_data"]:
            if variant == ModelRepoVariant.ONNX:
                result.add(path)

        elif "openvino_model" in path.name:
            if variant == ModelRepoVariant.OpenVINO:
            if variant == ModelRepoVariant.OPENVINO:
                result.add(path)

        elif "flax_model" in path.name:
            if variant == ModelRepoVariant.Flax:
            if variant == ModelRepoVariant.FLAX:
                result.add(path)

        elif path.suffix in [".json", ".txt"]:
            result.add(path)

        elif variant in [
        elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
            ModelRepoVariant.FP16,
            ModelRepoVariant.FP32,
            ModelRepoVariant.Default,
        ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
            # For weights files, we want to select the best one for each subfolder. For example, we may have multiple
            # text encoders:
            #
            # - text_encoder/model.fp16.safetensors
            # - text_encoder/model.safetensors
            # - text_encoder/pytorch_model.bin
            # - text_encoder/pytorch_model.fp16.bin
            #
            # We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
            # variant and format and select the best one.

            ModelRepoVariant.DEFAULT,
        ]:
            parent = path.parent
            score = 0
            suffixes = path.suffixes
            if len(suffixes) == 2:
                variant_label, suffix = suffixes
                basename = parent / Path(path.stem).stem
            else:
                variant_label = ""
                suffix = suffixes[0]
                basename = parent / path.stem

            if path.suffix == ".safetensors":
                score += 1

            candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None

            # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
            # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
            if candidate_variant_label == f".{variant}" or (
                not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
            ):
                score += 1

            if parent not in subfolder_weights:
                subfolder_weights[parent] = []

            subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
            if previous := basenames.get(basename):
                if (
                    previous.suffix != ".safetensors" and suffix == ".safetensors"
                ):  # replace non-safetensors with safetensors when available
                    basenames[basename] = path
                if variant_label == f".{variant}":
                    basenames[basename] = path
                elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
                    basenames[basename] = path
            else:
                basenames[basename] = path

        else:
            continue

    for candidate_list in subfolder_weights.values():
        highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
        if highest_score_candidate:
            result.add(highest_score_candidate.path)
    for v in basenames.values():
        result.add(v)

    # If one of the architecture-related variants was specified and no files matched other than
    # config and text files then we return an empty list
    if (
        variant
        and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
        and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
        and not any(variant.value in x.name for x in result)
    ):
        return set()
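The scoring side of this diff reduces to "prefer safetensors, prefer an exact variant match." A self-contained sketch of that ranking, with illustrative file names:

from pathlib import Path

def score(path: Path, variant: str = "fp16") -> int:
    s = 0
    if path.suffix == ".safetensors":
        s += 1  # prefer the safetensors format
    labels = path.suffixes
    if len(labels) == 2 and labels[0] == f".{variant}":
        s += 1  # exact variant match, e.g. model.fp16.safetensors
    return s

candidates = [
    Path("text_encoder/model.fp16.safetensors"),  # score 2
    Path("text_encoder/model.safetensors"),       # score 1
    Path("text_encoder/pytorch_model.fp16.bin"),  # score 1
    Path("text_encoder/pytorch_model.bin"),       # score 0
]
best = max(candidates, key=score)  # -> text_encoder/model.fp16.safetensors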
@ -4,11 +4,13 @@ Initialization file for the invokeai.backend.stable_diffusion package
"""

from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline  # noqa: F401
from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
from .seamless import set_seamless  # noqa: F401

__all__ = [
    "PipelineIntermediateState",
    "StableDiffusionGeneratorPipeline",
    "InvokeAIDiffuserComponent",
    "AttentionMapSaver",
    "set_seamless",
]
@ -12,6 +12,7 @@ import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
@ -25,9 +26,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent


@dataclass
@ -38,6 +39,7 @@ class PipelineIntermediateState:
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None
    attention_map_saver: Optional[AttentionMapSaver] = None


@dataclass
@ -188,6 +190,19 @@ class T2IAdapterData:
    end_step_percent: float = Field(default=1.0)


@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
    r"""
    Output class for InvokeAI's Stable Diffusion pipeline.

    Args:
        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
            after generation completes. Optional.
    """

    attention_map_saver: Optional[AttentionMapSaver]


class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.
@ -328,9 +343,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        masked_latents: Optional[torch.Tensor] = None,
        gradient_mask: Optional[bool] = False,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if init_timestep.shape[0] == 0:
            return latents
            return latents, None

        if additional_guidance is None:
            additional_guidance = []
@ -370,7 +385,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))

        try:
            latents = self.generate_latents_from_embeddings(
            latents, attention_map_saver = self.generate_latents_from_embeddings(
                latents,
                timesteps,
                conditioning_data,
@ -387,7 +402,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if mask is not None and not gradient_mask:
            latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))

        return latents
        return latents, attention_map_saver

    def generate_latents_from_embeddings(
        self,
@ -400,22 +415,23 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        ip_adapter_data: Optional[list[IPAdapterData]] = None,
        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
        callback: Callable[[PipelineIntermediateState], None] = None,
    ) -> torch.Tensor:
    ):
        self._adjust_memory_efficient_attention(latents)
        if additional_guidance is None:
            additional_guidance = []

        batch_size = latents.shape[0]
        attention_map_saver: Optional[AttentionMapSaver] = None

        if timesteps.shape[0] == 0:
            return latents
            return latents, attention_map_saver

        ip_adapter_unet_patcher = None
        extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
        if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                self.invokeai_diffuser.model,
                extra_conditioning_info=extra_conditioning_info,
                extra_conditioning_info=conditioning_data.extra,
                step_count=len(self.scheduler.timesteps),
            )
            self.use_ip_adapter = False
        elif ip_adapter_data is not None:
@ -466,6 +482,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

            predicted_original = getattr(step_output, "pred_original_sample", None)

            # TODO resuscitate attention map saving
            # if i == len(timesteps)-1 and extra_conditioning_info is not None:
            #     eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
            #     attention_map_token_ids = range(1, eos_token_index)
            #     attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
            #     self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

            if callback is not None:
                callback(
                    PipelineIntermediateState(
@ -475,10 +498,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                        timestep=int(t),
                        latents=latents,
                        predicted_original=predicted_original,
                        attention_map_saver=attention_map_saver,
                    )
                )

        return latents
        return latents, attention_map_saver

    @torch.inference_mode()
    def step(
@ -520,9 +544,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
                ip_adapter_unet_patcher.set_scale(i, 0.0)

        # Handle ControlNet(s)
        # Handle ControlNet(s) and T2I-Adapter(s)
        down_block_additional_residuals = None
        mid_block_additional_residual = None
        down_intrablock_additional_residuals = None
        # if control_data is not None and t2i_adapter_data is not None:
        #     TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
        #     between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
        #     raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
        # elif control_data is not None:
        if control_data is not None:
            down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
                control_data=control_data,
@ -532,9 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                total_step_count=total_step_count,
                conditioning_data=conditioning_data,
            )

        # Handle T2I-Adapter(s)
        down_intrablock_additional_residuals = None
        # elif t2i_adapter_data is not None:
        if t2i_adapter_data is not None:
            accum_adapter_state = None
            for single_t2i_adapter_data in t2i_adapter_data:
@ -560,6 +588,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
                    accum_adapter_state[idx] += value * t2i_adapter_weight

            # down_block_additional_residuals = accum_adapter_state
            down_intrablock_additional_residuals = accum_adapter_state

        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
@ -568,6 +597,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            step_index=step_index,
            total_step_count=total_step_count,
            conditioning_data=conditioning_data,
            # extra:
            down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
            mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter
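This hunk changes the latents-generation contract from a bare tensor to a (latents, attention_map_saver) tuple, so every call site has to unpack. A minimal sketch of the new calling convention; the pipeline and the trailing keyword arguments here are placeholders, since the full signature is elided from this hunk:

latents, attention_map_saver = pipeline.generate_latents_from_embeddings(
    latents, timesteps, conditioning_data, noise=None,  # other args elided
)
if attention_map_saver is not None:
    attention_map_saver.write_maps_to_disk("attention_maps.png")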
@ -2,4 +2,6 @@
Initialization file for invokeai.models.diffusion
"""

from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent  # noqa: F401
@ -21,7 +21,11 @@ class ExtraConditioningInfo:
@dataclass
class BasicConditioningInfo:
    embeds: torch.Tensor
    # TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
    # should only be stored in one place.
    extra_conditioning: Optional[ExtraConditioningInfo]
    # weight: float
    # mode: ConditioningAlgo

    def to(self, device, dtype=None):
        self.embeds = self.embeds.to(device=device, dtype=dtype)
@ -79,6 +83,7 @@ class ConditioningData:
    ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
    """
    guidance_rescale_multiplier: float = 0
    extra: Optional[ExtraConditioningInfo] = None
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    """
    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
@ -3,13 +3,19 @@


import enum
import math
from dataclasses import dataclass, field
from typing import Optional
from typing import Callable, Optional

import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from torch import nn

import invokeai.backend.util.logging as logger

from ...util import torch_dtype

@ -19,14 +25,72 @@ class CrossAttentionType(enum.Enum):
    TOKENS = 2


class CrossAttnControlContext:
    def __init__(self, arguments: Arguments):
class Context:
    cross_attention_mask: Optional[torch.Tensor]
    cross_attention_index_map: Optional[torch.Tensor]

    class Action(enum.Enum):
        NONE = 0
        SAVE = (1,)
        APPLY = 2

    def __init__(self, arguments: Arguments, step_count: int):
        """
        :param arguments: Arguments for the cross-attention control process
        :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
        """
        self.cross_attention_mask: Optional[torch.Tensor] = None
        self.cross_attention_index_map: Optional[torch.Tensor] = None
        self.cross_attention_mask = None
        self.cross_attention_index_map = None
        self.self_cross_attention_action = Context.Action.NONE
        self.tokens_cross_attention_action = Context.Action.NONE
        self.arguments = arguments
        self.step_count = step_count

        self.self_cross_attention_module_identifiers = []
        self.tokens_cross_attention_module_identifiers = []

        self.saved_cross_attention_maps = {}

        self.clear_requests(cleanup=True)

    def register_cross_attention_modules(self, model):
        for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
            if name in self.self_cross_attention_module_identifiers:
                raise AssertionError(f"name {name} cannot appear more than once")
            self.self_cross_attention_module_identifiers.append(name)
        for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
            if name in self.tokens_cross_attention_module_identifiers:
                raise AssertionError(f"name {name} cannot appear more than once")
            self.tokens_cross_attention_module_identifiers.append(name)

    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.SAVE
        else:
            self.tokens_cross_attention_action = Context.Action.SAVE

    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.APPLY
        else:
            self.tokens_cross_attention_action = Context.Action.APPLY

    def is_tokens_cross_attention(self, module_identifier) -> bool:
        return module_identifier in self.tokens_cross_attention_module_identifiers

    def get_should_save_maps(self, module_identifier: str) -> bool:
        if module_identifier in self.self_cross_attention_module_identifiers:
            return self.self_cross_attention_action == Context.Action.SAVE
        elif module_identifier in self.tokens_cross_attention_module_identifiers:
            return self.tokens_cross_attention_action == Context.Action.SAVE
        return False

    def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
        if module_identifier in self.self_cross_attention_module_identifiers:
            return self.self_cross_attention_action == Context.Action.APPLY
        elif module_identifier in self.tokens_cross_attention_module_identifiers:
            return self.tokens_cross_attention_action == Context.Action.APPLY
        return False

    def get_active_cross_attention_control_types_for_step(
        self, percent_through: Optional[float] = None
@ -47,8 +111,219 @@ class CrossAttnControlContext:
            to_control.append(CrossAttentionType.TOKENS)
        return to_control

    def save_slice(
        self,
        identifier: str,
        slice: torch.Tensor,
        dim: Optional[int],
        offset: int,
        slice_size: Optional[int],
    ):
        if identifier not in self.saved_cross_attention_maps:
            self.saved_cross_attention_maps[identifier] = {
                "dim": dim,
                "slice_size": slice_size,
                "slices": {offset or 0: slice},
            }
        else:
            self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice

def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
    def get_slice(
        self,
        identifier: str,
        requested_dim: Optional[int],
        requested_offset: int,
        slice_size: int,
    ):
        saved_attention_dict = self.saved_cross_attention_maps[identifier]
        if requested_dim is None:
            if saved_attention_dict["dim"] is not None:
                raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
            return saved_attention_dict["slices"][0]

        if saved_attention_dict["dim"] == requested_dim:
            if slice_size != saved_attention_dict["slice_size"]:
                raise RuntimeError(
                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
                )
            return saved_attention_dict["slices"][requested_offset]

        if saved_attention_dict["dim"] is None:
            whole_saved_attention = saved_attention_dict["slices"][0]
            if requested_dim == 0:
                return whole_saved_attention[requested_offset : requested_offset + slice_size]
            elif requested_dim == 1:
                return whole_saved_attention[:, requested_offset : requested_offset + slice_size]

        raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")

    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
        if saved_attention is None:
            return None, None
        return saved_attention["dim"], saved_attention["slice_size"]

    def clear_requests(self, cleanup=True):
        self.tokens_cross_attention_action = Context.Action.NONE
        self.self_cross_attention_action = Context.Action.NONE
        if cleanup:
            self.saved_cross_attention_maps = {}

    def offload_saved_attention_slices_to_cpu(self):
        for _key, map_dict in self.saved_cross_attention_maps.items():
            for offset, slice in map_dict["slices"].items():
                map_dict["slices"][offset] = slice.to("cpu")
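# A hedged sketch of the intended save/apply cycle for the Context above (names
# beyond Context and CrossAttentionType are assumptions drawn from this diff):
# a first pass records the original prompt's attention maps, a second pass
# re-injects them while the edited prompt diffuses.
#
#     context = Context(arguments=args, step_count=total_steps)  # args: compel Arguments
#     context.register_cross_attention_modules(unet)
#     context.request_save_attention_maps(CrossAttentionType.TOKENS)
#     # ...run a diffusion step with the original prompt; wranglers save maps...
#     context.request_apply_saved_attention_maps(CrossAttentionType.TOKENS)
#     # ...run the same step with the edited prompt; saved maps are swapped in...
#     context.clear_requests(cleanup=True)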
class InvokeAICrossAttentionMixin:
    """
    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
    and dynamic slicing strategy selection.
    """

    def __init__(self):
        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
        self.attention_slice_wrangler = None
        self.slicing_strategy_getter = None
        self.attention_slice_calculated_callback = None

    def set_attention_slice_wrangler(
        self,
        wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
    ):
        """
        Set custom attention calculator to be called when attention is calculated
        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
            which returns either the suggested_attention_slice or an adjusted equivalent.
            `module` is the current Attention module for which the callback is being invoked.
            `suggested_attention_slice` is the default-calculated attention slice
            `dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
            If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.

        Pass None to use the default attention calculation.
        :return:
        """
        self.attention_slice_wrangler = wrangler

    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
        self.slicing_strategy_getter = getter

    def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
        self.attention_slice_calculated_callback = callback

    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
        # calculate attention scores
        # attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )

        # calculate attention slice by taking the best scores for each latent pixel
        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
        attention_slice_wrangler = self.attention_slice_wrangler
        if attention_slice_wrangler is not None:
            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
        else:
            attention_slice = default_attention_slice

        if self.attention_slice_calculated_callback is not None:
            self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)

        hidden_states = torch.bmm(attention_slice, value)
        return hidden_states

    def einsum_op_slice_dim0(self, q, k, v, slice_size):
        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], slice_size):
            end = i + slice_size
            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
        return r

    def einsum_op_slice_dim1(self, q, k, v, slice_size):
        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
        return r

    def einsum_op_mps_v1(self, q, k, v):
        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
            return self.einsum_lowest_level(q, k, v, None, None, None)
        else:
            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
            return self.einsum_op_slice_dim1(q, k, v, slice_size)

    def einsum_op_mps_v2(self, q, k, v):
        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
            return self.einsum_lowest_level(q, k, v, None, None, None)
        else:
            return self.einsum_op_slice_dim0(q, k, v, 1)

    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
        if size_mb <= max_tensor_mb:
            return self.einsum_lowest_level(q, k, v, None, None, None)
        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
        if div <= q.shape[0]:
            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
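    # Worked example of the sizing math above (illustrative numbers): with
    # q.shape = [16, 4096, 40], k.shape = [16, 4096, 40] and fp16 elements,
    # size_mb = 16 * 4096 * 4096 * 2 // 2**20 = 512. For max_tensor_mb = 32,
    # div = 1 << int(511 / 32).bit_length() = 16, which slices dim 0 into
    # chunks of q.shape[0] // 16 = 1, keeping each score tensor at ~32 MB.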
    def einsum_op_cuda(self, q, k, v):
        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
        slicing_strategy_getter = self.slicing_strategy_getter
        if slicing_strategy_getter is not None:
            (dim, slice_size) = slicing_strategy_getter(self)
            if dim is not None:
                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
                if dim == 0:
                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
                elif dim == 1:
                    return self.einsum_op_slice_dim1(q, k, v, slice_size)

        # fallback for when there is no saved strategy, or saved strategy does not slice
        mem_free_total = get_mem_free_total(q.device)
        # Divide factor of safety as there's copying and fragmentation
        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

    def get_invokeai_attention_mem_efficient(self, q, k, v):
        if q.device.type == "cuda":
            # print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
            return self.einsum_op_cuda(q, k, v)

        if q.device.type == "mps" or q.device.type == "cpu":
            if self.mem_total_gb >= 32:
                return self.einsum_op_mps_v1(q, k, v)
            return self.einsum_op_mps_v2(q, k, v)

        # Smaller slices are faster due to L2/L3/SLC caches.
        # Tested on i7 with 8MB L3 cache.
        return self.einsum_op_tensor_mem(q, k, v, 32)


def restore_default_cross_attention(
    model,
    is_running_diffusers: bool,
    restore_attention_processor: Optional[AttentionProcessor] = None,
):
    if is_running_diffusers:
        unet = model
        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
    else:
        remove_attention_function(model)


def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.

@ -87,6 +362,170 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
    unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))


def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
    cross_attention_class: type = InvokeAIDiffusersCrossAttention
    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
    attention_module_tuples = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, cross_attention_class) and which_attn in name
    ]
    cross_attention_modules_in_model_count = len(attention_module_tuples)
    expected_count = 16
    if cross_attention_modules_in_model_count != expected_count:
        # non-fatal error but .swap() won't work.
        logger.error(
            f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
            f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
            "failed or some assumption has changed about the structure of the model itself. Please fix the "
            f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
            "inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
            "attention map display will not work properly until it is fixed."
        )
    return attention_module_tuples


def inject_attention_function(unet, context: Context):
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

    def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
        # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()

        attention_slice = suggested_attention_slice

        if context.get_should_save_maps(module.identifier):
            # print(module.identifier, "saving suggested_attention_slice of shape",
            #       suggested_attention_slice.shape, "dim", dim, "offset", offset)
            slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
            context.save_slice(
                module.identifier,
                slice_to_save,
                dim=dim,
                offset=offset,
                slice_size=slice_size,
            )
        elif context.get_should_apply_saved_maps(module.identifier):
            # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
            saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)

            # slice may have been offloaded to CPU
            saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)

            if context.is_tokens_cross_attention(module.identifier):
                index_map = context.cross_attention_index_map
                remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
                this_attention_slice = suggested_attention_slice

                mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
                saved_mask = mask
                this_mask = 1 - mask
                attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
            else:
                # just use everything
                attention_slice = saved_attention_slice

        return attention_slice

    cross_attention_modules = get_cross_attention_modules(
        unet, CrossAttentionType.TOKENS
    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for identifier, module in cross_attention_modules:
        module.identifier = identifier
        try:
            module.set_attention_slice_wrangler(attention_slice_wrangler)
            module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))  # noqa: B023
        except AttributeError as e:
            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")  # TODO
            else:
                raise


def remove_attention_function(unet):
    cross_attention_modules = get_cross_attention_modules(
        unet, CrossAttentionType.TOKENS
    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for _identifier, module in cross_attention_modules:
        try:
            # clear wrangler callback
            module.set_attention_slice_wrangler(None)
            module.set_slicing_strategy_getter(None)
        except AttributeError as e:
            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
            else:
                raise


def is_attribute_error_about(error: AttributeError, attribute: str):
    if hasattr(error, "name"):  # Python 3.10
        return error.name == attribute
    else:  # Python 3.9
        return attribute in str(error)


def get_mem_free_total(device):
    # only on cuda
    if not torch.cuda.is_available():
        return None
    stats = torch.cuda.memory_stats(device)
    mem_active = stats["active_bytes.all.current"]
    mem_reserved = stats["reserved_bytes.all.current"]
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    return mem_free_total
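# Illustrative arithmetic for the safety factor used in einsum_op_cuda: if
# get_mem_free_total() reports roughly 8 GiB free, the slicer budgets
# 8 * 2**30 / 3.3 / 2**20 ~= 2482 MB per attention-score tensor, leaving
# headroom for copies and allocator fragmentation.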
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        InvokeAICrossAttentionMixin.__init__(self)

    def _attention(self, query, key, value, attention_mask=None):
        # default_result = super()._attention(query, key, value)
        if attention_mask is not None:
            print(f"{type(self).__name__} ignoring passed-in attention_mask")
        attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)

        hidden_states = self.reshape_batch_dim_to_heads(attention_result)
        return hidden_states


## 🧨diffusers implementation follows


"""
# base implementation

class AttnProcessor:
    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

"""


@dataclass
class SwapCrossAttnContext:
    modified_text_embeddings: torch.Tensor
@ -94,6 +533,18 @@ class SwapCrossAttnContext:
    mask: torch.Tensor  # in the target space of the index_map
    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)

    def __init__(
        self,
        cac_types_to_do: list[CrossAttentionType],
        modified_text_embeddings: torch.Tensor,
        index_map: torch.Tensor,
        mask: torch.Tensor,
    ):
        self.cross_attention_types_to_do = cac_types_to_do
        self.modified_text_embeddings = modified_text_embeddings
        self.index_map = index_map
        self.mask = mask

    def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
        return attn_type in self.cross_attention_types_to_do
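For orientation, `SwapCrossAttnContext` is what the swap processors consult to decide which attention types to edit. A hedged sketch of constructing one; the tensor shapes and contents below are placeholders for illustration only:

import torch

swap_ctx = SwapCrossAttnContext(
    cac_types_to_do=[CrossAttentionType.TOKENS],
    modified_text_embeddings=torch.zeros(1, 77, 768),  # edited prompt embeddings
    index_map=torch.arange(77),                        # original->edited token index mapping
    mask=torch.ones(77),                               # 1.0 where saved maps should be kept
)
assert swap_ctx.wants_cross_attention_control(CrossAttentionType.TOKENS)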
@ -0,0 +1,100 @@
import math
from typing import Optional

import torch
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize


class AttentionMapSaver:
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps: dict[str, torch.Tensor] = {}

def clear_maps(self):
self.collated_maps = {}

def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is the attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the map sizes (H and W) should match.
:return: None
"""
key_and_size = f"{key}_{maps.shape[1]}"

# extract desired tokens
maps = maps[:, :, self.token_ids]

# merge attention heads to a single map per token
maps = torch.sum(maps, 0)

# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
self.collated_maps[key_and_size] += maps.cpu()

def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
if pil_image is not None:
pil_image.save(path, "PNG")

def get_stacked_maps_image(self) -> Optional[Image.Image]:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None

latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]

merged = None

for _key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(
torch.swapdims(maps, 0, 1),
[num_tokens, this_maps_height, this_maps_width],
)

# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)

# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
# print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.05, 1.05) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)

# merge together, producing a vertical stack
maps_stacked = torch.reshape(
maps_normalized_expanded_clamped,
[num_tokens * latents_height, latents_width],
)

if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked) * (1 - merged)

if merged is None:
return None

merged_bytes = merged.mul(0xFF).byte()
return Image.fromarray(merged_bytes.numpy(), mode="L")
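AttentionMapSaver above accumulates per-token attention maps under size-qualified keys and later flattens everything into one grayscale strip. A small usage sketch, assuming maps arrive shaped [heads, H*W, tokens] as the docstring describes (the random tensor stands in for real attention data):

import torch

saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=torch.Size([64, 64]))
# pretend callback data: 8 heads, a 32x32 attention map, 77 tokens
fake_maps = torch.rand(8, 32 * 32, 77)
saver.add_attention_maps(fake_maps, key="up")
saver.write_maps_to_disk("attention_maps.png")  # vertical stack, one blended map per token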
@ -17,11 +17,13 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)

from .cross_attention_control import (
Context,
CrossAttentionType,
CrossAttnControlContext,
SwapCrossAttnContext,
get_cross_attention_modules,
setup_cross_attention_control_attention_processors,
)
from .cross_attention_map_saving import AttentionMapSaver

ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
@ -67,12 +69,14 @@ class InvokeAIDiffuserComponent:
self,
unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
):
old_attn_processors = unet.attn_processors

try:
self.cross_attention_control_context = CrossAttnControlContext(
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
@ -83,6 +87,27 @@ class InvokeAIDiffuserComponent:
finally:
self.cross_attention_control_context = None
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()

def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
# sliced tokens attention map saving is not implemented
return
saver.add_attention_maps(slice, key)

tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
)

def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)

def do_controlnet_step(
self,
@ -199,47 +224,51 @@ class InvokeAIDiffuserComponent:
self,
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data: ConditioningData,
conditioning_data,  # TODO: type
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
**kwargs,
):
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
percent_through
)

wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0

if wants_cross_attention_control or self.sequential_guidance:
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
# control is currently only supported in sequential mode.
if wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
sample,
timestep,
conditioning_data,
cross_attention_control_types_to_do,
**kwargs,
)
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
sample,
timestep,
conditioning_data,
**kwargs,
)
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
sample,
timestep,
conditioning_data,
**kwargs,
)

return unconditioned_next_x, conditioned_next_x
@ -306,15 +335,7 @@ class InvokeAIDiffuserComponent:

# methods below are called from do_diffusion_step and should be considered private to this class.

def _apply_standard_conditioning(
self,
x,
sigma,
conditioning_data: ConditioningData,
down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
):
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
@ -362,10 +383,8 @@ class InvokeAIDiffuserComponent:
both_conditionings,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x
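The batched path above is plain classifier-free guidance: the unconditioned and conditioned inputs are concatenated so one UNet call serves both passes, then the result is split. A schematic sketch of the pattern (the forward callable and the final guidance combination are stand-ins, not code from this diff):

import torch

def batched_cfg_step(forward, x, t, uncond_embeds, cond_embeds, guidance_scale):
    x_twice = torch.cat([x] * 2)  # one batch, two passes
    t_twice = torch.cat([t] * 2)
    both_embeds = torch.cat([uncond_embeds, cond_embeds])
    both_results = forward(x_twice, t_twice, both_embeds)
    uncond_next_x, cond_next_x = both_results.chunk(2)
    # push the prediction away from the unconditioned result
    return uncond_next_x + guidance_scale * (cond_next_x - uncond_next_x)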
@ -375,17 +394,14 @@ class InvokeAIDiffuserComponent:
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
cross_attention_control_types_to_do: list[CrossAttentionType],
down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
**kwargs,
):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
# and T2I-Adapter residuals into two chunks.
# low-memory sequential path
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
@ -394,6 +410,7 @@ class InvokeAIDiffuserComponent:
cond_down_block.append(_cond_down)

uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
@ -402,29 +419,12 @@ class InvokeAIDiffuserComponent:
cond_down_intrablock.append(_cond_down)

uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
cross_attn_processor_context = None
if self.cross_attention_control_context is not None:
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
# will be populated before the conditioned pass.
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
index_map=self.cross_attention_control_context.cross_attention_index_map,
mask=self.cross_attention_control_context.cross_attention_mask,
cross_attention_types_to_do=[],
)

#####################
# Unconditioned pass
#####################

# Run unconditional UNet denoising.
cross_attention_kwargs = None

# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
@ -434,11 +434,6 @@ class InvokeAIDiffuserComponent:
]
}

# Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}

# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
@ -447,7 +442,6 @@ class InvokeAIDiffuserComponent:
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
}

# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
@ -457,15 +451,11 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)

###################
# Conditioned pass
###################

# Run conditional UNet denoising.
cross_attention_kwargs = None

# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
@ -475,12 +465,6 @@ class InvokeAIDiffuserComponent:
]
}

# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}

# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
if is_sdxl:
added_cond_kwargs = {
@ -488,7 +472,6 @@ class InvokeAIDiffuserComponent:
"time_ids": conditioning_data.text_embeddings.add_time_ids,
}

# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
sigma,
@ -498,6 +481,89 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x

def _apply_cross_attention_controlled_conditioning(
self,
x: torch.Tensor,
sigma,
conditioning_data,
cross_attention_control_types_to_do,
**kwargs,
):
context: Context = self.cross_attention_control_context

uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)

uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
_uncond_down, _cond_down = down_intrablock.chunk(2)
uncond_down_intrablock.append(_uncond_down)
cond_down_intrablock.append(_cond_down)

uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
mask=context.cross_attention_mask,
cross_attention_types_to_do=[],
)

added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
}

# no cross attention for unconditioning (negative prompt)
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)

if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
}

# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.text_embeddings.embeds,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x

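Both sequential paths share a two-phase use of SwapCrossAttnContext: it is built with an empty cross_attention_types_to_do list so the unconditioned pass runs without attention swapping, and the list is filled in just before the conditioned pass. A stripped-down sketch of that control flow (edited_embeds, index_map, mask, and model_forward are placeholders for values the real code pulls from the context):

swap_ctx = SwapCrossAttnContext(
    modified_text_embeddings=edited_embeds,
    index_map=index_map,
    mask=mask,
    cross_attention_types_to_do=[],  # swapping disabled for the negative prompt
)
uncond = model_forward(x, sigma, uncond_embeds, {"swap_cross_attn_context": swap_ctx})

swap_ctx.cross_attention_types_to_do = types_to_do  # enable swapping for the positive prompt
cond = model_forward(x, sigma, cond_embeds, {"swap_cross_attn_context": swap_ctx})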
@ -567,3 +633,54 @@ class InvokeAIDiffuserComponent:

self.last_percent_through = percent_through
return latents.to(device=dev)

# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)  # aka sigmas

deltas = None
uncond_latents = None
weighted_cond_list = (
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
)

# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]
weights = [1] + [weight for c, weight in weighted_cond_list]
chunk_count = math.ceil(len(conditionings) / 2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index * 2
chunk_size = min(2, len(conditionings) - offset)

if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset : offset + 2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
if chunk_index == 0:
uncond_latents = latents_a
deltas = latents_b - uncond_latents
else:
deltas = torch.cat((deltas, latents_a - uncond_latents))
if latents_b is not None:
deltas = torch.cat((deltas, latents_b - uncond_latents))

# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))

return uncond_latents + deltas_merged * global_guidance_scale
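apply_conjunction generalizes guidance to several weighted conditionings: each conditioning contributes a delta from the unconditioned latents, the deltas are combined by weight, and the merged delta is scaled once. With weights w_i and deltas d_i = cond_i - uncond, the result is uncond + scale * sum_i(w_i * d_i). A tiny numeric sketch of just the merge step (shapes and values are illustrative):

import torch

uncond = torch.zeros(1, 4, 8, 8)
deltas = torch.stack([torch.full((4, 8, 8), 0.2), torch.full((4, 8, 8), -0.1)])  # two conditionings
weights = torch.tensor([1.0, 0.5]).reshape(2, 1, 1, 1)
merged = torch.sum(deltas * weights, dim=0, keepdim=True)  # weighted sum of deltas
result = uncond + merged * 7.5  # one global guidance scale applied at the end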
@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
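The hunk above only reflows the assignment, but the underlying technique is worth spelling out: after each optimizer step, every embedding row except the newly added placeholder token is restored from a saved copy, so only that one row actually trains. A minimal sketch of the masking idea with a toy embedding table (standalone, not the real text encoder):

import torch

vocab_size, dim, placeholder_token_id = 10, 4, 7
embeddings = torch.nn.Embedding(vocab_size, dim)
orig_weights = embeddings.weight.detach().clone()

# ... an optimizer step may have nudged every row ...

index_no_updates = torch.arange(vocab_size) != placeholder_token_id
with torch.no_grad():
    embeddings.weight[index_no_updates] = orig_weights[index_no_updates]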
@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):

self.nextrely = top_of_table
self.lora_models = self.add_model_widgets(
model_type=ModelType.LoRA,
model_type=ModelType.Lora,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)

@ -30,7 +30,7 @@
"lint:prettier": "prettier --check .",
"lint:tsc": "tsc --noEmit",
"lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
"fix": "eslint --fix . && prettier --log-level warn --write .",
"fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
"preinstall": "npx only-allow pnpm",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",

@ -134,6 +134,8 @@
"loadMore": "Mehr laden",
"noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade",
"preparingDownload": "bereite Download vor",
"preparingDownloadFailed": "Problem beim Download vorbereiten",
"deleteImage": "Lösche Bild",
"copy": "Kopieren",
"download": "Runterladen",
@ -965,7 +967,7 @@
"resumeFailed": "Problem beim Fortsetzen des Prozesses",
"pruneFailed": "Problem beim leeren der Warteschlange",
"pauseTooltip": "Prozess anhalten",
"back": "Ende",
"back": "Hinten",
"resumeSucceeded": "Prozess wird fortgesetzt",
"resumeTooltip": "Prozess wieder aufnehmen",
"time": "Zeit",

@ -78,7 +78,6 @@
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"add": "Add",
"advanced": "Advanced",
"advancedOptions": "Advanced Options",
"ai": "ai",
@ -304,12 +303,6 @@
"method": "High Resolution Fix Method"
}
},
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noPromptTriggers": "No triggers available",
"noMatchingTriggers": "No matching triggers"
},
"embedding": {
"addEmbedding": "Add Embedding",
"incompatibleModel": "Incompatible base model:",
@ -741,8 +734,6 @@
"customConfig": "Custom Config",
"customConfigFileLocation": "Custom Config File Location",
"customSaveLocation": "Custom Save Location",
"defaultSettings": "Default Settings",
"defaultSettingsSaved": "Default Settings Saved",
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
@ -777,7 +768,6 @@
"mergedModelName": "Merged Model Name",
"mergedModelSaveLocation": "Save Location",
"mergeModels": "Merge Models",
"metadata": "Metadata",
"model": "Model",
"modelAdded": "Model Added",
"modelConversionFailed": "Model Conversion Failed",
@ -849,12 +839,9 @@
"statusConverting": "Converting",
"syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"updateModel": "Update Model",
"useCustomConfig": "Use Custom Config",
"useDefaultSettings": "Use Default Settings",
"v1": "v1",
"v2_768": "v2 (768px)",
"v2_base": "v2 (512px)",
@ -873,7 +860,6 @@
"models": {
"addLora": "Add LoRA",
"allLoRAsAdded": "All LoRAs added",
"concepts": "Concepts",
"loraAlreadyAdded": "LoRA already added",
"esrganModel": "ESRGAN Model",
"loading": "loading",

@ -505,6 +505,8 @@
"seamLowThreshold": "Bajo",
"coherencePassHeader": "Parámetros de la coherencia",
"compositingSettingsHeader": "Ajustes de la composición",
"coherenceSteps": "Pasos",
"coherenceStrength": "Fuerza",
"patchmatchDownScaleSize": "Reducir a escala",
"coherenceMode": "Modo"
},

@ -114,8 +114,7 @@
"checkpoint": "Checkpoint",
"safetensors": "Safetensors",
"ai": "ia",
"file": "File",
"toResolve": "Da risolvere"
"file": "File"
},
"gallery": {
"generations": "Generazioni",
@ -143,6 +142,8 @@
"copy": "Copia",
"download": "Scarica",
"setCurrentImage": "Imposta come immagine corrente",
"preparingDownload": "Preparazione del download",
"preparingDownloadFailed": "Problema durante la preparazione del download",
"downloadSelection": "Scarica gli elementi selezionati",
"noImageSelected": "Nessuna immagine selezionata",
"deleteSelection": "Elimina la selezione",
@ -608,6 +609,8 @@
"seamLowThreshold": "Basso",
"seamHighThreshold": "Alto",
"coherencePassHeader": "Passaggio di coerenza",
"coherenceSteps": "Passi",
"coherenceStrength": "Forza",
"compositingSettingsHeader": "Impostazioni di composizione",
"patchmatchDownScaleSize": "Ridimensiona",
"coherenceMode": "Modalità",
@ -1397,6 +1400,19 @@
"Regola la maschera."
]
},
"compositingCoherenceSteps": {
"heading": "Passi",
"paragraphs": [
"Numero di passi utilizzati nel Passaggio di Coerenza.",
"Simile ai passi di generazione."
]
},
"compositingBlur": {
"heading": "Sfocatura",
"paragraphs": [
"Il raggio di sfocatura della maschera."
]
},
"compositingCoherenceMode": {
"heading": "Modalità",
"paragraphs": [
@ -1415,6 +1431,13 @@
"Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
]
},
"compositingStrength": {
"heading": "Forza",
"paragraphs": [
"Quantità di rumore aggiunta per il Passaggio di Coerenza.",
"Simile alla forza di riduzione del rumore."
]
},
"paramNegativeConditioning": {
"paragraphs": [
"Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",

@ -123,6 +123,8 @@
"autoSwitchNewImages": "새로운 이미지로 자동 전환",
"loading": "불러오는 중",
"unableToLoad": "갤러리를 로드할 수 없음",
"preparingDownload": "다운로드 준비",
"preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
"singleColumnLayout": "단일 열 레이아웃",
"image": "이미지",
"loadMore": "더 불러오기",

@ -97,6 +97,8 @@
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
"loading": "Bezig met laden",
"unableToLoad": "Kan galerij niet laden",
"preparingDownload": "Bezig met voorbereiden van download",
"preparingDownloadFailed": "Fout bij voorbereiden van download",
"downloadSelection": "Download selectie",
"currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
"copy": "Kopieer",
@ -533,6 +535,8 @@
"coherencePassHeader": "Coherentiestap",
"maskBlur": "Vervaag",
"maskBlurMethod": "Vervagingsmethode",
"coherenceSteps": "Stappen",
"coherenceStrength": "Sterkte",
"seamHighThreshold": "Hoog",
"seamLowThreshold": "Laag",
"invoke": {
@ -1135,6 +1139,13 @@
"Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
]
},
"compositingCoherenceSteps": {
"heading": "Stappen",
"paragraphs": [
"Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
"Gelijk aan de hoofdparameter Stappen."
]
},
"dynamicPrompts": {
"paragraphs": [
"Dynamische prompts vormt een enkele prompt om in vele.",
@ -1149,6 +1160,12 @@
],
"heading": "VAE"
},
"compositingBlur": {
"heading": "Vervaging",
"paragraphs": [
"De vervagingsstraal van het masker."
]
},
"paramIterations": {
"paragraphs": [
"Het aantal te genereren afbeeldingen.",
@ -1223,6 +1240,13 @@
],
"heading": "Ontruisingssterkte"
},
"compositingStrength": {
"heading": "Sterkte",
"paragraphs": [
"Ontruisingssterkte voor de coherentiefase.",
"Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
]
},
"paramNegativeConditioning": {
"paragraphs": [
"Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",

@ -143,6 +143,8 @@
"problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
"loading": "Загрузка",
"unableToLoad": "Невозможно загрузить галерею",
"preparingDownload": "Подготовка к скачиванию",
"preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
"image": "изображение",
"drop": "перебросить",
"problemDeletingImages": "Проблема с удалением изображений",
@ -610,7 +612,9 @@
"maskBlurMethod": "Метод размытия",
"seamLowThreshold": "Низкий",
"seamHighThreshold": "Высокий",
"coherenceSteps": "Шагов",
"coherencePassHeader": "Порог Coherence",
"coherenceStrength": "Сила",
"compositingSettingsHeader": "Настройки компоновки",
"invoke": {
"noNodesInGraph": "Нет узлов в графе",
@ -1317,6 +1321,13 @@
"Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
]
},
"compositingCoherenceSteps": {
"heading": "Шаги",
"paragraphs": [
"Количество шагов снижения шума, используемых при прохождении когерентности.",
"То же, что и основной параметр «Шаги»."
]
},
"dynamicPrompts": {
"paragraphs": [
"Динамические запросы превращают одно приглашение на множество.",
@ -1331,6 +1342,12 @@
],
"heading": "VAE"
},
"compositingBlur": {
"heading": "Размытие",
"paragraphs": [
"Радиус размытия маски."
]
},
"paramIterations": {
"paragraphs": [
"Количество изображений, которые нужно сгенерировать.",
@ -1405,6 +1422,13 @@
],
"heading": "Шумоподавление"
},
"compositingStrength": {
"heading": "Сила",
"paragraphs": [
null,
"То же, что параметр «Сила шумоподавления img2img»."
]
},
"paramNegativeConditioning": {
"paragraphs": [
"Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",

@ -355,6 +355,7 @@
"starImage": "Yıldız Koy",
"download": "İndir",
"deleteSelection": "Seçileni Sil",
"preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
"problemDeletingImages": "Görsel Silmede Sorun",
"featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
"galleryImageResetSize": "Boyutu Resetle",
@ -376,6 +377,7 @@
"setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi",
"downloadSelection": "Seçileni İndir",
"preparingDownload": "İndirmeye Hazırlanıyor",
"singleColumnLayout": "Tek Sütun Düzen",
"generations": "Çıktılar",
"showUploads": "Yüklenenleri Göster",
@ -721,6 +723,7 @@
"clipSkip": "CLIP Atlama",
"randomizeSeed": "Rastgele Tohum",
"cfgScale": "CFG Ölçeği",
"coherenceStrength": "Etki",
"controlNetControlMode": "Yönetim Kipi",
"general": "Genel",
"img2imgStrength": "Görselden Görsel Ölçüsü",
@ -790,6 +793,7 @@
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",
"cfgRescale": "CFG Rescale",
"coherencePassHeader": "Uyum Geçişi",
"coherenceSteps": "Adım",
"infillMethod": "Doldurma Yöntemi",
"maskBlurMethod": "Bulandırma Yöntemi",
"steps": "Adım",

@ -136,6 +136,8 @@
"copy": "复制",
"download": "下载",
"setCurrentImage": "设为当前图像",
"preparingDownload": "准备下载",
"preparingDownloadFailed": "准备下载时出现问题",
"downloadSelection": "下载所选内容",
"noImageSelected": "无选中的图像",
"deleteSelection": "删除所选内容",
@ -614,9 +616,11 @@
"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
},
"patchmatchDownScaleSize": "缩小",
"coherenceSteps": "步数",
"clipSkip": "CLIP 跳过层",
"compositingSettingsHeader": "合成设置",
"useCpuNoise": "使用 CPU 噪声",
"coherenceStrength": "强度",
"enableNoiseSettings": "启用噪声设置",
"coherenceMode": "模式",
"cpuNoise": "CPU 噪声",
@ -1398,6 +1402,19 @@
"图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸,SDXL 模型使用等效 1024x1024 的尺寸。"
]
},
"compositingCoherenceSteps": {
"heading": "步数",
"paragraphs": [
"一致性层中使用的去噪步数。",
"与主参数中的步数相同。"
]
},
"compositingBlur": {
"heading": "模糊",
"paragraphs": [
"遮罩模糊半径。"
]
},
"noiseUseCPU": {
"heading": "使用 CPU 噪声",
"paragraphs": [
@ -1450,6 +1467,13 @@
"第二轮去噪有助于合成内补/外扩图像。"
]
},
"compositingStrength": {
"heading": "强度",
"paragraphs": [
"一致性层使用的去噪强度。",
"去噪强度与图生图的参数相同。"
]
},
"paramNegativeConditioning": {
"paragraphs": [
"生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",

@ -55,8 +55,6 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';

import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';

export const listenerMiddleware = createListenerMiddleware();

export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@ -153,7 +151,5 @@ addFirstListImagesListener(startAppListening);
// Ad-hoc upscale workflow
addUpscaleRequestedListener(startAppListening);

// Prompts
// Dynamic prompts
addDynamicPromptsListener(startAppListening);

addSetDefaultSettingsListener(startAppListening);

@ -7,10 +7,8 @@ import {
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
@ -26,9 +24,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);

const state = getState();

const currentModel = state.generation.model;
const currentModel = getState().generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload);

if (models.length === 0) {
@ -43,29 +39,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
return;
}

const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;

if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));

const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);

dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}

const result = zParameterModel.safeParse(models[0]);

if (!result.success) {
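The block removed above resized the canvas toward the default model's optimal pixel area: calculateNewSize picks a width/height pair that keeps the current aspect ratio while hitting roughly optimalDimension squared pixels. A hedged sketch of that arithmetic (the real helper may round and snap differently):

import math

def calculate_new_size(aspect_ratio: float, target_area: int) -> tuple[int, int]:
    # preserve the ratio, hit roughly the target area, snap to multiples of 8
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    snap = 8
    return int(round(width / snap) * snap), int(round(height / snap) * snap)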
@ -1,96 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
setCfgScale,
setScheduler,
setSteps,
vaePrecisionChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterPrecision,
isParameterScheduler,
isParameterSteps,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';

export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
const state = getState();

const currentModel = state.generation.model;

if (!currentModel) {
return;
}

const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();

if (!modelConfig || !modelConfig.default_settings) {
return;
}

const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;

if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);

const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
}

if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}

if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}

if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}

if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}

if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}

dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
},
});
};
@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { bytes, total_bytes, id } = action.payload.data;

dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id } = action.payload.data;

dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft;
})
);
dispatch(api.util.invalidateTags(['Model']));
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
},
});

@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id, error, error_type } = action.payload.data;

dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';

@ -1,5 +1,4 @@
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt';

@ -83,8 +82,6 @@ export type AppConfig = {
guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
// Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model

@ -8,15 +8,15 @@ type Props = {
onOpen: () => void;
};

export const AddPromptTriggerButton = memo((props: Props) => {
export const AddEmbeddingButton = memo((props: Props) => {
const { onOpen, isOpen } = props;
const { t } = useTranslation();
return (
<Tooltip label={t('prompt.addPromptTrigger')}>
<Tooltip label={t('embedding.addEmbedding')}>
<IconButton
variant="promptOverlay"
isDisabled={isOpen}
aria-label={t('prompt.addPromptTrigger')}
aria-label={t('embedding.addEmbedding')}
icon={<PiCodeBold />}
onClick={onOpen}
/>
@ -24,4 +24,4 @@ export const AddPromptTriggerButton = memo((props: Props) => {
);
});

AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';
AddEmbeddingButton.displayName = 'AddEmbeddingButton';
@ -1,9 +1,9 @@
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
import type { PromptPopoverProps } from 'features/prompt/types';
import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
import type { EmbeddingPopoverProps } from 'features/embedding/types';
import { memo } from 'react';

export const PromptPopover = memo((props: PromptPopoverProps) => {
export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
const { onSelect, isOpen, onClose, width, children } = props;

return (
@ -14,7 +14,7 @@ export const PromptPopover = memo((props: PromptPopoverProps) => {
openDelay={0}
closeDelay={0}
closeOnBlur={true}
returnFocusOnClose={false}
returnFocusOnClose={true}
isLazy
>
<PopoverAnchor>{children}</PopoverAnchor>
@ -27,11 +27,11 @@ export const PromptPopover = memo((props: PromptPopoverProps) => {
borderStyle="solid"
>
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
<EmbeddingSelect onClose={onClose} onSelect={onSelect} />
</PopoverBody>
</PopoverContent>
</Popover>
);
});

PromptPopover.displayName = 'PromptPopover';
EmbeddingPopover.displayName = 'EmbeddingPopover';
@ -0,0 +1,21 @@
import type { Meta, StoryObj } from '@storybook/react';

import { EmbeddingSelect } from './EmbeddingSelect';
import type { EmbeddingSelectProps } from './types';

const meta: Meta<typeof EmbeddingSelect> = {
title: 'Feature/Prompt/EmbeddingSelect',
tags: ['autodocs'],
component: EmbeddingSelect,
};

export default meta;
type Story = StoryObj<typeof EmbeddingSelect>;

const Component = (props: EmbeddingSelectProps) => {
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
};

export const Default: Story = {
render: Component,
};
@ -0,0 +1,67 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionModelConfig } from 'services/api/types';

const noOptionsMessage = () => t('embedding.noMatchingEmbedding');

export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
const { t } = useTranslation();

const currentBaseModel = useAppSelector((s) => s.generation.model?.base);

const getIsDisabled = useCallback(
(embedding: TextualInversionModelConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { data, isLoading } = useGetTextualInversionModelsQuery();

const _onChange = useCallback(
(embedding: TextualInversionModelConfig | null) => {
if (!embedding) {
return;
}
onSelect(embedding.name);
},
[onSelect]
);

const { options, onChange } = useGroupedModelCombobox({
modelEntities: data,
getIsDisabled,
onChange: _onChange,
});

return (
<FormControl>
<Combobox
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
noOptionsMessage={noOptionsMessage}
onChange={onChange}
onMenuClose={onClose}
data-testid="add-embedding"
sx={selectStyles}
/>
</FormControl>
);
});

EmbeddingSelect.displayName = 'EmbeddingSelect';

const selectStyles: ChakraProps['sx'] = {
w: 'full',
};
@ -1,12 +1,12 @@
import type { PropsWithChildren } from 'react';

export type PromptTriggerSelectProps = {
export type EmbeddingSelectProps = {
onSelect: (v: string) => void;
onClose: () => void;
};

export type PromptPopoverProps = PropsWithChildren &
PromptTriggerSelectProps & {
export type EmbeddingPopoverProps = PropsWithChildren &
EmbeddingSelectProps & {
isOpen: boolean;
width?: number | string;
};
@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
import { useCallback } from 'react';
import { flushSync } from 'react-dom';

type UseInsertTriggerArg = {
type UseInsertEmbeddingArg = {
prompt: string;
textareaRef: RefObject<HTMLTextAreaElement>;
onChange: (v: string) => void;
};

export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
const { isOpen, onClose, onOpen } = useDisclosure();

const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
[_onChange]
);

const insertTrigger = useCallback(
const insertEmbedding = useCallback(
(v: string) => {
if (!textareaRef.current) {
return;
}

// this is where we insert the trigger
// this is where we insert the TI trigger
const caret = textareaRef.current.selectionStart;

if (isNil(caret)) {
@ -35,9 +35,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser

let newPrompt = prompt.slice(0, caret);

newPrompt += `${v}`;
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}

// we insert the cursor after the end of trigger
newPrompt += `${v}>`;

// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;

newPrompt += prompt.slice(caret);
@ -47,7 +51,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
_onChange(newPrompt);
});

// set the cursor position to just after the trigger
// set the caret position to just after the TI trigger
textareaRef.current.selectionStart = finalCaretPos;
textareaRef.current.selectionEnd = finalCaretPos;
},
@ -58,17 +62,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
textareaRef.current?.focus();
}, [textareaRef]);

const handleClosePopover = useCallback(() => {
const handleClose = useCallback(() => {
onClose();
onFocus();
}, [onFocus, onClose]);

const onSelect = useCallback(
const onSelectEmbedding = useCallback(
(v: string) => {
insertTrigger(v);
handleClosePopover();
insertEmbedding(v);
handleClose();
},
[handleClosePopover, insertTrigger]
[handleClose, insertEmbedding]
);

const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
@ -86,7 +90,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
isOpen,
onClose,
onOpen,
onSelect,
onSelectEmbedding,
onKeyDown,
onFocus,
};
@ -59,7 +59,7 @@ const LoRASelect = () => {
return (
<FormControl isDisabled={!options.length}>
<InformationalPopover feature="lora">
<FormLabel>{t('models.concepts')} </FormLabel>
<FormLabel>{t('models.lora')} </FormLabel>
</InformationalPopover>
<Combobox
placeholder={placeholder}

@ -0,0 +1,228 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';

export const AdvancedImport = () => {
  const dispatch = useAppDispatch();

  const [installModel] = useInstallModelMutation();

  const { t } = useTranslation();

  const {
    register,
    handleSubmit,
    control,
    formState: { errors },
    setValue,
    resetField,
    reset,
    watch,
  } = useForm<AnyModelConfig>({
    defaultValues: {
      name: '',
      base: 'sd-1',
      type: 'main',
      path: '',
      description: '',
      format: 'diffusers',
      vae: '',
      variant: 'normal',
    },
    mode: 'onChange',
  });

  const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
    (values) => {
      installModel({
        source: values.path,
        config: omitBy(values, isNil),
      })
        .unwrap()
        .then((_) => {
          dispatch(
            addToast(
              makeToast({
                title: t('modelManager.modelAdded', {
                  modelName: values.name,
                }),
                status: 'success',
              })
            )
          );
          reset();
        })
        .catch((error) => {
          if (error) {
            dispatch(
              addToast(
                makeToast({
                  title: t('toast.modelAddFailed'),
                  status: 'error',
                })
              )
            );
          }
        });
    },
    [installModel, dispatch, t, reset]
  );

  const watchedModelType = watch('type');
  const watchedModelFormat = watch('format');

  useEffect(() => {
    if (watchedModelType === 'main') {
      setValue('format', 'diffusers');
      setValue('repo_variant', '');
      setValue('variant', 'normal');
    }
    if (watchedModelType === 'lora') {
      setValue('format', 'lycoris');
    } else if (watchedModelType === 'embedding') {
      setValue('format', 'embedding_file');
    } else if (watchedModelType === 'ip_adapter') {
      setValue('format', 'invokeai');
    } else {
      setValue('format', 'diffusers');
    }
    resetField('upcast_attention');
    resetField('ztsnr_training');
    resetField('vae');
    resetField('config');
    resetField('prediction_type');
    resetField('image_encoder_model_id');
  }, [watchedModelType, resetField, setValue]);

  return (
    <ScrollableContent>
      <form onSubmit={handleSubmit(onSubmit)}>
        <Flex flexDirection="column" gap={4} width="100%" pb={10}>
          <Flex alignItems="flex-end" gap="4">
            <FormControl flexDir="column" alignItems="flex-start" gap={1}>
              <FormLabel>{t('modelManager.modelType')}</FormLabel>
              <ModelTypeSelect<AnyModelConfig> control={control} name="type" />
            </FormControl>
            <Text px="2" fontSize="xs" textAlign="center">
              {t('modelManager.advancedImportInfo')}
            </Text>
          </Flex>

          <Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
            <FormControl isInvalid={Boolean(errors.name)}>
              <Flex direction="column" width="full">
                <FormLabel>{t('modelManager.name')}</FormLabel>
                <Input
                  {...register('name', {
                    validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
                  })}
                />
                {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
              </Flex>
            </FormControl>
            <Flex>
              <FormControl>
                <Flex direction="column" width="full">
                  <FormLabel>{t('modelManager.description')}</FormLabel>
                  <Textarea size="sm" {...register('description')} />
                  {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
                </Flex>
              </FormControl>
            </Flex>
            <Flex gap={4}>
              <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                <FormLabel>{t('modelManager.baseModel')}</FormLabel>
                <BaseModelSelect control={control} name="base" />
              </FormControl>
              <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                <FormLabel>{t('common.format')}</FormLabel>
                <ModelFormatSelect control={control} name="format" />
              </FormControl>
            </Flex>
            <Flex gap={4}>
              <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
                <FormLabel>{t('modelManager.path')}</FormLabel>
                <Input
                  {...register('path', {
                    validate: (value) => value.trim().length > 0 || 'Must provide a path',
                  })}
                />
                {errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
              </FormControl>
            </Flex>
            {watchedModelType === 'main' && (
              <>
                <Flex gap={4}>
                  {watchedModelFormat === 'diffusers' && (
                    <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                      <FormLabel>{t('modelManager.repoVariant')}</FormLabel>
                      <RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
                    </FormControl>
                  )}
                  {watchedModelFormat === 'checkpoint' && (
                    <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                      <FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
                      <Input {...register('config')} />
                    </FormControl>
                  )}

                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                    <FormLabel>{t('modelManager.variant')}</FormLabel>
                    <ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
                  </FormControl>
                </Flex>
                <Flex gap={4}>
                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                    <FormLabel>{t('modelManager.predictionType')}</FormLabel>
                    <PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
                  </FormControl>
                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                    <FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
                    <BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
                  </FormControl>
                </Flex>
                <Flex gap={4}>
                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                    <FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
                    <BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
                  </FormControl>
                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                    <FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
                    <Input {...register('vae')} />
                  </FormControl>
                </Flex>
              </>
            )}
            {watchedModelType === 'ip_adapter' && (
              <Flex gap={4}>
                <FormControl flexDir="column" alignItems="flex-start" gap={1}>
                  <FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
                  <Input {...register('image_encoder_model_id')} />
                </FormControl>
              </Flex>
            )}
            <Button mt={2} type="submit">
              {t('modelManager.addModel')}
            </Button>
          </Flex>
        </Flex>
      </form>
    </ScrollableContent>
  );
};
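
The new AdvancedImport form above leans on two react-hook-form features: watch() to re-render when a field changes, and setValue()/resetField() inside an effect to keep dependent fields consistent with the selected model type. A toy sketch of that pattern, with illustrative field names (kind, format) rather than the real config shape:

import { useEffect } from 'react';
import { useForm } from 'react-hook-form';

type ToyConfig = { kind: 'main' | 'lora'; format: string };

export const ToyForm = () => {
  const { watch, setValue, register } = useForm<ToyConfig>({
    defaultValues: { kind: 'main', format: 'diffusers' },
  });
  const kind = watch('kind'); // component re-renders when this field changes

  useEffect(() => {
    // keep the dependent field consistent with the selected kind
    setValue('format', kind === 'lora' ? 'lycoris' : 'diffusers');
  }, [kind, setValue]);

  return (
    <form>
      <select {...register('kind')}>
        <option value="main">main</option>
        <option value="lora">lora</option>
      </select>
      <input {...register('format')} readOnly />
    </form>
  );
};
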
@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';

import { ModelInstallQueueItem } from './ModelInstallQueueItem';
import { ImportQueueItem } from './ImportQueueItem';

export const ModelInstallQueue = () => {
export const ImportQueue = () => {
  const dispatch = useAppDispatch();

  const { data } = useListModelInstallsQuery();
  const { data } = useGetModelImportsQuery();

  const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
  const [pruneModelImports] = usePruneModelImportsMutation();

  const pruneCompletedModelInstalls = useCallback(() => {
    _pruneCompletedModelInstalls()
  const pruneQueue = useCallback(() => {
    pruneModelImports()
      .unwrap()
      .then((_) => {
        dispatch(
@ -41,7 +41,7 @@ export const ModelInstallQueue = () => {
        );
      }
    });
  }, [_pruneCompletedModelInstalls, dispatch]);
  }, [pruneModelImports, dispatch]);

  const pruneAvailable = useMemo(() => {
    return data?.some(
@ -53,19 +53,14 @@ export const ModelInstallQueue = () => {
    <Flex flexDir="column" p={3} h="full">
      <Flex justifyContent="space-between" alignItems="center">
        <Text>{t('modelManager.importQueue')}</Text>
        <Button
          size="sm"
          isDisabled={!pruneAvailable}
          onClick={pruneCompletedModelInstalls}
          tooltip={t('modelManager.pruneTooltip')}
        >
        <Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
          {t('modelManager.prune')}
        </Button>
      </Flex>
      <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
        <ScrollableContent>
          <Flex flexDir="column-reverse" gap="2">
            {data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
            {data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
          </Flex>
        </ScrollableContent>
      </Box>
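
The prune handler above follows the usual RTK Query mutation flow: the hook returns a trigger, and .unwrap() converts its result into a plain promise that resolves on success and rejects on error, so toasts can hang off .then()/.catch(). A sketch of that wiring under assumed names — the api slice and the pruneImports endpoint below are stand-ins, not the project's real definitions:

import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

const api = createApi({
  baseQuery: fetchBaseQuery({ baseUrl: '/api/' }),
  endpoints: (build) => ({
    // hypothetical endpoint: delete completed entries from the import queue
    pruneImports: build.mutation<void, void>({
      query: () => ({ url: 'models/import', method: 'DELETE' }),
    }),
  }),
});

export const { usePruneImportsMutation } = api;

// In a component:
//   const [prune] = usePruneImportsMutation();
//   prune().unwrap().then(showSuccessToast).catch(showErrorToast);
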
@ -6,24 +6,17 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = {
  waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
  downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
  downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
  running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
  completed: { colorScheme: 'green', translationKey: 'queue.completed' },
  error: { colorScheme: 'red', translationKey: 'queue.failed' },
  cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};

const ModelInstallQueueBadge = ({
  status,
  errorReason,
}: {
  status?: ModelInstallStatus;
  errorReason?: string | null;
}) => {
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
  const { t } = useTranslation();

  if (!status || !Object.keys(STATUSES).includes(status)) {
    return null;
  if (!status) {
    return <></>;
  }

  return (
@ -32,4 +25,4 @@ const ModelInstallQueueBadge = ({
    </Tooltip>
  );
};
export default memo(ModelInstallQueueBadge);
export default memo(ImportQueueBadge);
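
The badge keys a lookup map by install status; one side of the hunk guards with Object.keys(STATUSES).includes(status) before indexing. A typed map plus a type guard is an alternative way to express the same check so the compiler can verify the keys — a sketch with illustrative names, not code from this diff:

type InstallStatus = 'waiting' | 'downloading' | 'downloads_done' | 'running' | 'completed' | 'error' | 'cancelled';

const BADGES: Record<InstallStatus, { colorScheme: string; translationKey: string }> = {
  waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
  downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
  downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
  running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
  completed: { colorScheme: 'green', translationKey: 'queue.completed' },
  error: { colorScheme: 'red', translationKey: 'queue.failed' },
  cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};

// Narrow an optional string to a key of BADGES before indexing into it.
const isKnownStatus = (s: string | undefined): s is InstallStatus => s !== undefined && s in BADGES;
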
@ -3,16 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';

import ModelInstallQueueBadge from './ModelInstallQueueBadge';
import ImportQueueBadge from './ImportQueueBadge';

type ModelListItemProps = {
  installJob: ModelInstallJob;
  model: ModelInstallJob;
};

const formatBytes = (bytes: number) => {
@ -27,26 +26,26 @@ const formatBytes = (bytes: number) => {
  return `${bytes.toFixed(2)} ${units[i]}`;
};

export const ModelInstallQueueItem = (props: ModelListItemProps) => {
  const { installJob } = props;
export const ImportQueueItem = (props: ModelListItemProps) => {
  const { model } = props;
  const dispatch = useAppDispatch();

  const [deleteImportModel] = useCancelModelInstallMutation();
  const [deleteImportModel] = useDeleteModelImportMutation();

  const source = useMemo(() => {
    if (installJob.source.type === 'hf') {
      return installJob.source as HFModelSource;
    } else if (installJob.source.type === 'local') {
      return installJob.source as LocalModelSource;
    } else if (installJob.source.type === 'url') {
      return installJob.source as URLModelSource;
    if (model.source.type === 'hf') {
      return model.source as HFModelSource;
    } else if (model.source.type === 'local') {
      return model.source as LocalModelSource;
    } else if (model.source.type === 'url') {
      return model.source as URLModelSource;
    } else {
      return installJob.source as LocalModelSource;
      return model.source as LocalModelSource;
    }
  }, [installJob.source]);
  }, [model.source]);

  const handleDeleteModelImport = useCallback(() => {
    deleteImportModel(installJob.id)
    deleteImportModel(model.id)
      .unwrap()
      .then((_) => {
        dispatch(
@ -70,7 +69,7 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
        );
      }
    });
  }, [deleteImportModel, installJob, dispatch]);
  }, [deleteImportModel, model, dispatch]);

  const modelName = useMemo(() => {
    switch (source.type) {
@ -86,23 +85,19 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
  }, [source]);

  const progressValue = useMemo(() => {
    if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
      return null;
    }

    if (installJob.total_bytes === 0) {
    if (model.bytes === undefined || model.total_bytes === undefined) {
      return 0;
    }

    return (installJob.bytes / installJob.total_bytes) * 100;
  }, [installJob.bytes, installJob.total_bytes]);
    return (model.bytes / model.total_bytes) * 100;
  }, [model.bytes, model.total_bytes]);

  const progressString = useMemo(() => {
    if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
    if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
      return '';
    }
    return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
  }, [installJob.bytes, installJob.total_bytes, installJob.status]);
    return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
  }, [model.bytes, model.total_bytes, model.status]);

  return (
    <Flex gap="2" w="full" alignItems="center">
@ -114,21 +109,19 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
      <Flex flexDir="column" flex={1}>
        <Tooltip label={progressString}>
          <Progress
            value={progressValue ?? 0}
            isIndeterminate={progressValue === null}
            value={progressValue}
            isIndeterminate={progressValue === undefined}
            aria-label={t('accessibility.invokeProgressBar')}
            h={2}
          />
        </Tooltip>
      </Flex>
      <Box minW="100px" textAlign="center">
        <ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
        <ImportQueueBadge status={model.status} errorReason={model.error_reason} />
      </Box>

      <Box minW="20px">
        {(installJob.status === 'downloading' ||
          installJob.status === 'waiting' ||
          installJob.status === 'running') && (
        {(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
          <IconButton
            isRound={true}
            size="xs"
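
formatBytes is mostly elided by the hunk above; only its final return line is visible. A self-contained sketch consistent with that return line — the unit list and loop are assumptions, not necessarily the project's actual body:

const formatBytes = (bytes: number): string => {
  const units = ['B', 'KB', 'MB', 'GB', 'TB'];
  let i = 0;
  // divide down until the value fits the current unit, or units run out
  for (; bytes >= 1024 && i < units.length - 1; i++) {
    bytes = bytes / 1024;
  }
  return `${bytes.toFixed(2)} ${units[i]}`;
};

// formatBytes(1536) → '1.50 KB'; formatBytes(3221225472) → '3.00 GB'
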
@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';

import { ScanModelsResults } from './ScanFolderResults';
import { ScanModelsResults } from './ScanModelsResults';

export const ScanModelsForm = () => {
  const [scanPath, setScanPath] = useState('');
  const [errorMessage, setErrorMessage] = useState('');
  const { t } = useTranslation();

  const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
  const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();

  const scanFolder = useCallback(async () => {
    _scanFolder({ scan_path: scanPath }).catch((error) => {
  const handleSubmitScan = useCallback(async () => {
    _scanModels({ scan_path: scanPath }).catch((error) => {
      if (error) {
        setErrorMessage(error.data.detail);
      }
    });
  }, [_scanFolder, scanPath]);
  }, [_scanModels, scanPath]);

  const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
    setScanPath(e.target.value);
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
        <Input value={scanPath} onChange={handleSetScanPath} />
      </Flex>

      <Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
      <Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
        {t('modelManager.scanFolder')}
      </Button>
    </Flex>
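
The scan form uses an RTK Query lazy query: unlike a regular query hook, which fetches on mount, the lazy variant returns a trigger that fires only when called (here, on the button click), while isLoading and data still update reactively. A sketch of such an endpoint under assumed names — the scanFolder definition below is a stand-in, not the project's real one:

import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

const modelsApi = createApi({
  baseQuery: fetchBaseQuery({ baseUrl: '/api/' }),
  endpoints: (build) => ({
    // hypothetical endpoint: list model files found under a folder path
    scanFolder: build.query<string[], { scan_path: string }>({
      query: ({ scan_path }) => ({ url: 'models/scan_folder', params: { scan_path } }),
    }),
  }),
});

export const { useLazyScanFolderQuery } = modelsApi;

// const [trigger, { isLoading, data }] = useLazyScanFolderQuery();
// trigger({ scan_path: '/some/folder' }); // fires the request on demand
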
@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';

import { ScanModelResultItem } from './ScanFolderResultItem';
import { ScanModelResultItem } from './ScanModelResultItem';

type ScanModelResultsProps = {
  results: ScanFolderResponse;
@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
  location: string;
};

export const InstallModelForm = () => {
export const SimpleImport = () => {
  const dispatch = useAppDispatch();

  const [installModel, { isLoading }] = useInstallModelMutation();
Some files were not shown because too many files have changed in this diff.