diff --git a/.github/actions/install-frontend-deps/action.yml b/.github/actions/install-frontend-deps/action.yml index b9d910ca99..32b4987249 100644 --- a/.github/actions/install-frontend-deps/action.yml +++ b/.github/actions/install-frontend-deps/action.yml @@ -1,33 +1,33 @@ -name: Install frontend dependencies +name: install frontend dependencies description: Installs frontend dependencies with pnpm, with caching runs: using: 'composite' steps: - - name: Setup Node 18 + - name: setup node 18 uses: actions/setup-node@v4 with: node-version: '18' - - name: Setup pnpm + - name: setup pnpm uses: pnpm/action-setup@v2 with: version: 8 run_install: false - - name: Get pnpm store directory + - name: get pnpm store directory shell: bash run: | echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV - - uses: actions/cache@v3 - name: Setup pnpm cache + - name: setup cache + uses: actions/cache@v4 with: path: ${{ env.STORE_PATH }} key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }} restore-keys: | ${{ runner.os }}-pnpm-store- - - name: Install frontend dependencies + - name: install frontend dependencies run: pnpm install --prefer-frozen-lockfile shell: bash working-directory: invokeai/frontend/web diff --git a/.github/actions/install-python-deps/action.yml b/.github/actions/install-python-deps/action.yml deleted file mode 100644 index 4c0d351899..0000000000 --- a/.github/actions/install-python-deps/action.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Install python dependencies -description: Install python dependencies with pip, with caching -runs: - using: 'composite' - steps: - - name: Setup python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - cache: pip - cache-dependency-path: pyproject.toml diff --git a/.github/pr_labels.yml b/.github/pr_labels.yml index 9580ccb0be..fdf11a470f 100644 --- a/.github/pr_labels.yml +++ b/.github/pr_labels.yml @@ -1,59 +1,59 @@ -Root: +root: - changed-files: - any-glob-to-any-file: '*' -PythonDeps: +python-deps: - changed-files: - any-glob-to-any-file: 'pyproject.toml' -Python: +python: - changed-files: - all-globs-to-any-file: - 'invokeai/**' - '!invokeai/frontend/web/**' -PythonTests: +python-tests: - changed-files: - any-glob-to-any-file: 'tests/**' -CICD: +ci-cd: - changed-files: - any-glob-to-any-file: .github/** -Docker: +docker: - changed-files: - any-glob-to-any-file: docker/** -Installer: +installer: - changed-files: - any-glob-to-any-file: installer/** -Documentation: +docs: - changed-files: - any-glob-to-any-file: docs/** -Invocations: +invocations: - changed-files: - any-glob-to-any-file: 'invokeai/app/invocations/**' -Backend: +backend: - changed-files: - any-glob-to-any-file: 'invokeai/backend/**' -Api: +api: - changed-files: - any-glob-to-any-file: 'invokeai/app/api/**' -Services: +services: - changed-files: - any-glob-to-any-file: 'invokeai/app/services/**' -FrontendDeps: +frontend-deps: - changed-files: - any-glob-to-any-file: - '**/*/package.json' - '**/*/pnpm-lock.yaml' -Frontend: +frontend: - changed-files: - any-glob-to-any-file: 'invokeai/frontend/web/**' diff --git a/.github/workflows/build-installer.yml b/.github/workflows/build-installer.yml new file mode 100644 index 0000000000..9d6d42c8d9 --- /dev/null +++ b/.github/workflows/build-installer.yml @@ -0,0 +1,45 @@ +# Builds and uploads the installer and python build artifacts. + +name: build installer + +on: + workflow_dispatch: + workflow_call: + +jobs: + build-installer: + runs-on: ubuntu-latest + timeout-minutes: 5 # expected run time: <2 min + steps: + - name: checkout + uses: actions/checkout@v4 + + - name: setup python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: pip + cache-dependency-path: pyproject.toml + + - name: install pypa/build + run: pip install --upgrade build + + - name: setup frontend + uses: ./.github/actions/install-frontend-deps + + - name: create installer + id: create_installer + run: ./create_installer.sh + working-directory: installer + + - name: upload python distribution artifact + uses: actions/upload-artifact@v4 + with: + name: dist + path: ${{ steps.create_installer.outputs.DIST_PATH }} + + - name: upload installer artifact + uses: actions/upload-artifact@v4 + with: + name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }} + path: ${{ steps.create_installer.outputs.INSTALLER_PATH }} diff --git a/.github/workflows/check-frontend.yml b/.github/workflows/check-frontend.yml deleted file mode 100644 index 8134926556..0000000000 --- a/.github/workflows/check-frontend.yml +++ /dev/null @@ -1,43 +0,0 @@ -# This workflow runs the frontend code quality checks. -# -# It may be triggered via dispatch, or by another workflow. - -name: 'Check: frontend' - -on: - workflow_dispatch: - workflow_call: - -defaults: - run: - working-directory: invokeai/frontend/web - -jobs: - check-frontend: - runs-on: ubuntu-latest - timeout-minutes: 10 # expected run time: <2 min - steps: - - uses: actions/checkout@v4 - - - name: Set up frontend - uses: ./.github/actions/install-frontend-deps - - - name: Run tsc check - run: 'pnpm run lint:tsc' - shell: bash - - - name: Run dpdm check - run: 'pnpm run lint:dpdm' - shell: bash - - - name: Run eslint check - run: 'pnpm run lint:eslint' - shell: bash - - - name: Run prettier check - run: 'pnpm run lint:prettier' - shell: bash - - - name: Run knip check - run: 'pnpm run lint:knip' - shell: bash diff --git a/.github/workflows/check-pytest.yml b/.github/workflows/check-pytest.yml deleted file mode 100644 index aedc0e59c3..0000000000 --- a/.github/workflows/check-pytest.yml +++ /dev/null @@ -1,72 +0,0 @@ -# This workflow runs pytest on the codebase in a matrix of platforms. -# -# It may be triggered via dispatch, or by another workflow. - -name: 'Check: pytest' - -on: - workflow_dispatch: - workflow_call: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - matrix: - strategy: - matrix: - python-version: - - '3.10' - pytorch: - - linux-cuda-11_7 - - linux-rocm-5_2 - - linux-cpu - - macos-default - - windows-cpu - include: - - pytorch: linux-cuda-11_7 - os: ubuntu-22.04 - github-env: $GITHUB_ENV - - pytorch: linux-rocm-5_2 - os: ubuntu-22.04 - extra-index-url: 'https://download.pytorch.org/whl/rocm5.2' - github-env: $GITHUB_ENV - - pytorch: linux-cpu - os: ubuntu-22.04 - extra-index-url: 'https://download.pytorch.org/whl/cpu' - github-env: $GITHUB_ENV - - pytorch: macos-default - os: macOS-12 - github-env: $GITHUB_ENV - - pytorch: windows-cpu - os: windows-2022 - github-env: $env:GITHUB_ENV - name: ${{ matrix.pytorch }} on ${{ matrix.python-version }} - runs-on: ${{ matrix.os }} - timeout-minutes: 30 # expected run time: <10 min, depending on platform - env: - PIP_USE_PEP517: '1' - steps: - - uses: actions/checkout@v4 - - - name: set test prompt to main branch validation - run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }} - - - name: setup python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: pyproject.toml - - - name: install invokeai - env: - PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }} - run: > - pip3 install - --editable=".[test]" - - - name: run pytest - id: run-pytest - run: pytest diff --git a/.github/workflows/check-python.yml b/.github/workflows/check-python.yml deleted file mode 100644 index 63a6c46b0a..0000000000 --- a/.github/workflows/check-python.yml +++ /dev/null @@ -1,33 +0,0 @@ -# This workflow runs the python code quality checks. -# -# It may be triggered via dispatch, or by another workflow. -# -# TODO: Add mypy or pyright to the checks. - -name: 'Check: python' - -on: - workflow_dispatch: - workflow_call: - -jobs: - check-backend: - runs-on: ubuntu-latest - timeout-minutes: 5 # expected run time: <1 min - steps: - - uses: actions/checkout@v4 - - - name: Install python dependencies - uses: ./.github/actions/install-python-deps - - - name: Install ruff - run: pip install ruff - shell: bash - - - name: Ruff check - run: ruff check --output-format=github . - shell: bash - - - name: Ruff format - run: ruff format --check . - shell: bash diff --git a/.github/workflows/frontend-checks.yml b/.github/workflows/frontend-checks.yml new file mode 100644 index 0000000000..e621348af4 --- /dev/null +++ b/.github/workflows/frontend-checks.yml @@ -0,0 +1,68 @@ +# Runs frontend code quality checks. +# +# Checks for changes to frontend files before running the checks. +# When manually triggered or when called from another workflow, always runs the checks. + +name: 'frontend checks' + +on: + push: + branches: + - 'main' + pull_request: + types: + - 'ready_for_review' + - 'opened' + - 'synchronize' + merge_group: + workflow_dispatch: + workflow_call: + +defaults: + run: + working-directory: invokeai/frontend/web + +jobs: + frontend-checks: + runs-on: ubuntu-latest + timeout-minutes: 10 # expected run time: <2 min + steps: + - uses: actions/checkout@v4 + + - name: check for changed frontend files + if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }} + id: changed-files + uses: tj-actions/changed-files@v42 + with: + files_yaml: | + frontend: + - 'invokeai/frontend/web/**' + + - name: install dependencies + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + uses: ./.github/actions/install-frontend-deps + + - name: tsc + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: 'pnpm lint:tsc' + shell: bash + + - name: dpdm + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: 'pnpm lint:dpdm' + shell: bash + + - name: eslint + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: 'pnpm lint:eslint' + shell: bash + + - name: prettier + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: 'pnpm lint:prettier' + shell: bash + + - name: knip + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: 'pnpm lint:knip' + shell: bash diff --git a/.github/workflows/frontend-tests.yml b/.github/workflows/frontend-tests.yml new file mode 100644 index 0000000000..e4e18f2571 --- /dev/null +++ b/.github/workflows/frontend-tests.yml @@ -0,0 +1,48 @@ +# Runs frontend tests. +# +# Checks for changes to frontend files before running the tests. +# When manually triggered or called from another workflow, always runs the tests. + +name: 'frontend tests' + +on: + push: + branches: + - 'main' + pull_request: + types: + - 'ready_for_review' + - 'opened' + - 'synchronize' + merge_group: + workflow_dispatch: + workflow_call: + +defaults: + run: + working-directory: invokeai/frontend/web + +jobs: + frontend-tests: + runs-on: ubuntu-latest + timeout-minutes: 10 # expected run time: <2 min + steps: + - uses: actions/checkout@v4 + + - name: check for changed frontend files + if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }} + id: changed-files + uses: tj-actions/changed-files@v42 + with: + files_yaml: | + frontend: + - 'invokeai/frontend/web/**' + + - name: install dependencies + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + uses: ./.github/actions/install-frontend-deps + + - name: vitest + if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: 'pnpm test:no-watch' + shell: bash diff --git a/.github/workflows/label-pr.yml b/.github/workflows/label-pr.yml index bc14e2f2c8..1a98512190 100644 --- a/.github/workflows/label-pr.yml +++ b/.github/workflows/label-pr.yml @@ -1,6 +1,6 @@ -name: "Pull Request Labeler" +name: 'label PRs' on: -- pull_request_target + - pull_request_target jobs: labeler: @@ -9,8 +9,10 @@ jobs: pull-requests: write runs-on: ubuntu-latest steps: - - name: Checkout + - name: checkout uses: actions/checkout@v4 - - uses: actions/labeler@v5 + + - name: label PRs + uses: actions/labeler@v5 with: - configuration-path: .github/pr_labels.yml \ No newline at end of file + configuration-path: .github/pr_labels.yml diff --git a/.github/workflows/mkdocs-material.yml b/.github/workflows/mkdocs-material.yml index cbcfbf0835..419d87f37b 100644 --- a/.github/workflows/mkdocs-material.yml +++ b/.github/workflows/mkdocs-material.yml @@ -21,18 +21,29 @@ jobs: SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI' steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - name: checkout + uses: actions/checkout@v4 + + - name: setup python + uses: actions/setup-python@v5 with: python-version: '3.10' cache: pip cache-dependency-path: pyproject.toml - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - uses: actions/cache@v4 + + - name: set cache id + run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV + + - name: use cache + uses: actions/cache@v4 with: key: mkdocs-material-${{ env.cache_id }} path: .cache restore-keys: | mkdocs-material- - - run: python -m pip install ".[docs]" - - run: mkdocs gh-deploy --force + + - name: install dependencies + run: python -m pip install ".[docs]" + + - name: build & deploy + run: mkdocs gh-deploy --force diff --git a/.github/workflows/on-change-check-frontend.yml b/.github/workflows/on-change-check-frontend.yml deleted file mode 100644 index 5e8704ad71..0000000000 --- a/.github/workflows/on-change-check-frontend.yml +++ /dev/null @@ -1,39 +0,0 @@ -# This workflow runs of `check-frontend.yml` on push or pull request. -# -# The actual checks are in a separate workflow to support simpler workflow -# composition without awkward or complicated conditionals. - -name: 'On change: run check-frontend' - -on: - push: - branches: - - 'main' - pull_request: - types: - - 'ready_for_review' - - 'opened' - - 'synchronize' - merge_group: - -jobs: - check-changed-frontend-files: - if: github.event.pull_request.draft == false - runs-on: ubuntu-latest - outputs: - frontend_any_changed: ${{ steps.changed-files.outputs.frontend_any_changed }} - steps: - - uses: actions/checkout@v4 - - - name: Check for changed frontend files - id: changed-files - uses: tj-actions/changed-files@v41 - with: - files_yaml: | - frontend: - - 'invokeai/frontend/web/**' - - run-check-frontend: - needs: check-changed-frontend-files - if: ${{ needs.check-changed-frontend-files.outputs.frontend_any_changed == 'true' }} - uses: ./.github/workflows/check-frontend.yml diff --git a/.github/workflows/on-change-check-python.yml b/.github/workflows/on-change-check-python.yml deleted file mode 100644 index e73198b3fa..0000000000 --- a/.github/workflows/on-change-check-python.yml +++ /dev/null @@ -1,42 +0,0 @@ -# This workflow runs of `check-python.yml` on push or pull request. -# -# The actual checks are in a separate workflow to support simpler workflow -# composition without awkward or complicated conditionals. - -name: 'On change: run check-python' - -on: - push: - branches: - - 'main' - pull_request: - types: - - 'ready_for_review' - - 'opened' - - 'synchronize' - merge_group: - -jobs: - check-changed-python-files: - if: github.event.pull_request.draft == false - runs-on: ubuntu-latest - outputs: - python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }} - steps: - - uses: actions/checkout@v4 - - - name: Check for changed python files - id: changed-files - uses: tj-actions/changed-files@v41 - with: - files_yaml: | - python: - - 'pyproject.toml' - - 'invokeai/**' - - '!invokeai/frontend/web/**' - - 'tests/**' - - run-check-python: - needs: check-changed-python-files - if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }} - uses: ./.github/workflows/check-python.yml diff --git a/.github/workflows/on-change-pytest.yml b/.github/workflows/on-change-pytest.yml deleted file mode 100644 index 0c174098bb..0000000000 --- a/.github/workflows/on-change-pytest.yml +++ /dev/null @@ -1,42 +0,0 @@ -# This workflow runs of `check-pytest.yml` on push or pull request. -# -# The actual checks are in a separate workflow to support simpler workflow -# composition without awkward or complicated conditionals. - -name: 'On change: run pytest' - -on: - push: - branches: - - 'main' - pull_request: - types: - - 'ready_for_review' - - 'opened' - - 'synchronize' - merge_group: - -jobs: - check-changed-python-files: - if: github.event.pull_request.draft == false - runs-on: ubuntu-latest - outputs: - python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }} - steps: - - uses: actions/checkout@v4 - - - name: Check for changed python files - id: changed-files - uses: tj-actions/changed-files@v41 - with: - files_yaml: | - python: - - 'pyproject.toml' - - 'invokeai/**' - - '!invokeai/frontend/web/**' - - 'tests/**' - - run-pytest: - needs: check-changed-python-files - if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }} - uses: ./.github/workflows/check-pytest.yml diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml new file mode 100644 index 0000000000..cbf986d8da --- /dev/null +++ b/.github/workflows/python-checks.yml @@ -0,0 +1,64 @@ +# Runs python code quality checks. +# +# Checks for changes to python files before running the checks. +# When manually triggered or called from another workflow, always runs the tests. +# +# TODO: Add mypy or pyright to the checks. + +name: 'python checks' + +on: + push: + branches: + - 'main' + pull_request: + types: + - 'ready_for_review' + - 'opened' + - 'synchronize' + merge_group: + workflow_dispatch: + workflow_call: + +jobs: + python-checks: + runs-on: ubuntu-latest + timeout-minutes: 5 # expected run time: <1 min + steps: + - name: checkout + uses: actions/checkout@v4 + + - name: check for changed python files + if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }} + id: changed-files + uses: tj-actions/changed-files@v42 + with: + files_yaml: | + python: + - 'pyproject.toml' + - 'invokeai/**' + - '!invokeai/frontend/web/**' + - 'tests/**' + + - name: setup python + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: pip + cache-dependency-path: pyproject.toml + + - name: install ruff + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: pip install ruff + shell: bash + + - name: ruff check + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: ruff check --output-format=github . + shell: bash + + - name: ruff format + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: ruff format --check . + shell: bash diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 0000000000..d261a90451 --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,94 @@ +# Runs python tests on a matrix of python versions and platforms. +# +# Checks for changes to python files before running the tests. +# When manually triggered or called from another workflow, always runs the tests. + +name: 'python tests' + +on: + push: + branches: + - 'main' + pull_request: + types: + - 'ready_for_review' + - 'opened' + - 'synchronize' + merge_group: + workflow_dispatch: + workflow_call: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + matrix: + strategy: + matrix: + python-version: + - '3.10' + - '3.11' + platform: + - linux-cuda-11_7 + - linux-rocm-5_2 + - linux-cpu + - macos-default + - windows-cpu + include: + - platform: linux-cuda-11_7 + os: ubuntu-22.04 + github-env: $GITHUB_ENV + - platform: linux-rocm-5_2 + os: ubuntu-22.04 + extra-index-url: 'https://download.pytorch.org/whl/rocm5.2' + github-env: $GITHUB_ENV + - platform: linux-cpu + os: ubuntu-22.04 + extra-index-url: 'https://download.pytorch.org/whl/cpu' + github-env: $GITHUB_ENV + - platform: macos-default + os: macOS-12 + github-env: $GITHUB_ENV + - platform: windows-cpu + os: windows-2022 + github-env: $env:GITHUB_ENV + name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}' + runs-on: ${{ matrix.os }} + timeout-minutes: 15 # expected run time: 2-6 min, depending on platform + env: + PIP_USE_PEP517: '1' + steps: + - name: checkout + uses: actions/checkout@v4 + + - name: check for changed python files + if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }} + id: changed-files + uses: tj-actions/changed-files@v42 + with: + files_yaml: | + python: + - 'pyproject.toml' + - 'invokeai/**' + - '!invokeai/frontend/web/**' + - 'tests/**' + + - name: setup python + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: pyproject.toml + + - name: install dependencies + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + env: + PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }} + run: > + pip3 install --editable=".[test]" + + - name: run pytest + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }} + run: pytest diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0f9ca098d5..037a082722 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,103 +1,96 @@ -name: Release +# Main release workflow. Triggered on tag push or manual trigger. +# +# - Runs all code checks and tests +# - Verifies the app version matches the tag version. +# - Builds the installer and build, uploading them as artifacts. +# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval. +# +# See docs/RELEASE.md for more information on the release process. + +name: release on: push: tags: - 'v*' workflow_dispatch: - inputs: - skip_code_checks: - description: 'Skip code checks' - required: true - default: true - type: boolean jobs: check-version: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - name: checkout + uses: actions/checkout@v4 - - uses: samuelcolvin/check-python-version@v4 + - name: check python version + uses: samuelcolvin/check-python-version@v4 id: check-python-version with: version_file_path: invokeai/version/invokeai_version.py - check-frontend: - if: github.event.inputs.skip_code_checks != 'true' - uses: ./.github/workflows/check-frontend.yml + frontend-checks: + uses: ./.github/workflows/frontend-checks.yml - check-python: - if: github.event.inputs.skip_code_checks != 'true' - uses: ./.github/workflows/check-python.yml + frontend-tests: + uses: ./.github/workflows/frontend-tests.yml - check-pytest: - if: github.event.inputs.skip_code_checks != 'true' - uses: ./.github/workflows/check-pytest.yml + python-checks: + uses: ./.github/workflows/python-checks.yml + + python-tests: + uses: ./.github/workflows/python-tests.yml build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install python dependencies - uses: ./.github/actions/install-python-deps - - - name: Install pypa/build - run: pip install --upgrade build - - - name: Setup frontend - uses: ./.github/actions/install-frontend-deps - - - name: Run create_installer.sh - id: create_installer - run: ./create_installer.sh --skip_frontend_checks - working-directory: installer - - - name: Upload python distribution artifact - uses: actions/upload-artifact@v4 - with: - name: dist - path: ${{ steps.create_installer.outputs.DIST_PATH }} - - - name: Upload installer artifact - uses: actions/upload-artifact@v4 - with: - name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }} - path: ${{ steps.create_installer.outputs.INSTALLER_PATH }} + uses: ./.github/workflows/build-installer.yml publish-testpypi: runs-on: ubuntu-latest - needs: [check-version, check-frontend, check-python, check-pytest, build] - if: github.event_name != 'workflow_dispatch' + timeout-minutes: 5 # expected run time: <1 min + needs: + [ + check-version, + frontend-checks, + frontend-tests, + python-checks, + python-tests, + build, + ] environment: name: testpypi url: https://test.pypi.org/p/invokeai steps: - - name: Download distribution from build job + - name: download distribution from build job uses: actions/download-artifact@v4 with: name: dist path: dist/ - - name: Publish distribution to TestPyPI + - name: publish distribution to TestPyPI uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://test.pypi.org/legacy/ publish-pypi: runs-on: ubuntu-latest - needs: [check-version, check-frontend, check-python, check-pytest, build] - if: github.event_name != 'workflow_dispatch' + timeout-minutes: 5 # expected run time: <1 min + needs: + [ + check-version, + frontend-checks, + frontend-tests, + python-checks, + python-tests, + build, + ] environment: name: pypi url: https://pypi.org/p/invokeai steps: - - name: Download distribution from build job + - name: download distribution from build job uses: actions/download-artifact@v4 with: name: dist path: dist/ - - name: Publish distribution to PyPI + - name: publish distribution to PyPI uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/docs/RELEASE.md b/docs/RELEASE.md index 3a0375b027..82bf68b535 100644 --- a/docs/RELEASE.md +++ b/docs/RELEASE.md @@ -23,13 +23,13 @@ It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if Run `make tag-release` to tag the current commit and kick off the workflow. -The release may also be run [manually]. +The release may also be dispatched [manually]. ### Workflow Jobs and Process The workflow consists of a number of concurrently-run jobs, and two final publish jobs. -The publish jobs run if the 5 concurrent jobs all succeed and if/when the publish jobs are approved. +The publish jobs require manual approval and are only run if the other jobs succeed. #### `check-version` Job @@ -43,17 +43,16 @@ This job uses [samuelcolvin/check-python-version]. #### Check and Test Jobs -This is our test suite. - -- **`check-pytest`**: runs `pytest` on matrix of platforms -- **`check-python`**: runs `ruff` (format and lint) -- **`check-frontend`**: runs `prettier` (format), `eslint` (lint), `madge` (circular refs) and `tsc` (static type check) +- **`python-tests`**: runs `pytest` on matrix of platforms +- **`python-checks`**: runs `ruff` (format and lint) +- **`frontend-tests`**: runs `vitest` +- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports) > **TODO** We should add `mypy` or `pyright` to the **`check-python`** job. > **TODO** We should add an end-to-end test job that generates an image. -#### `build` Job +#### `build-installer` Job This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts: @@ -62,7 +61,7 @@ This sets up both python and frontend dependencies and builds the python package #### Sanity Check & Smoke Test -At this point, the release workflow pauses (the remaining jobs all require approval). +At this point, the release workflow pauses as the remaining publish jobs require approval. A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates. @@ -70,7 +69,7 @@ A maintainer should go to the **Summary** tab of the workflow, download the inst #### PyPI Publish Jobs -The publish jobs will skip if any of the previous jobs skip or fail. +The publish jobs will run if any of the previous jobs fail. They use [GitHub environments], which are configured as [trusted publishers] on PyPI. @@ -119,13 +118,17 @@ Once the release is published to PyPI, it's time to publish the GitHub release. > **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up. -## Manually Running the Release Workflow +## Manual Build -The release workflow can be run manually. This is useful to get an installer build and test it out without needing to push a tag. +The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag. -When run this way, you'll see **Skip code checks** checkbox. This allows the workflow to run without the time-consuming 3 code quality check jobs. +No checks are run, it just builds. -The publish jobs will skip if the workflow was run manually. +## Manual Release + +The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check. + +This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above. [InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases [PyPI]: https://pypi.org/ @@ -136,4 +139,4 @@ The publish jobs will skip if the workflow was run manually. [GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment [trusted publishers]: https://docs.pypi.org/trusted-publishers/ [samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version -[manually]: #manually-running-the-release-workflow +[manually]: #manual-release diff --git a/docs/nodes/defaultNodes.md b/docs/nodes/defaultNodes.md index f62332da24..b78c9af901 100644 --- a/docs/nodes/defaultNodes.md +++ b/docs/nodes/defaultNodes.md @@ -19,6 +19,8 @@ their descriptions. | Conditioning Primitive | A conditioning tensor primitive value | | Content Shuffle Processor | Applies content shuffle processing to image | | ControlNet | Collects ControlNet info to pass to other nodes | +| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. | +| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. | | Denoise Latents | Denoises noisy latents to decodable images | | Divide Integers | Divides two numbers | | Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator | diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 50ebe5ce64..78ef965d5a 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -14,6 +14,7 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_install import ModelInstallJob +from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, @@ -32,6 +33,7 @@ from invokeai.backend.model_manager.config import ( ) from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata +from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata from invokeai.backend.model_manager.search import ModelSearch from ..dependencies import ApiDependencies @@ -243,6 +245,47 @@ async def get_model_metadata( return result +@model_manager_router.patch( + "/i/{key}/metadata", + operation_id="update_model_metadata", + responses={ + 201: { + "description": "The model metadata was updated successfully", + "content": {"application/json": {"example": example_model_metadata}}, + }, + 400: {"description": "Bad request"}, + }, +) +async def update_model_metadata( + key: str = Path(description="Key of the model repo metadata to fetch."), + changes: ModelMetadataChanges = Body(description="The changes"), +) -> Optional[AnyModelRepoMetadata]: + """Updates or creates a model metadata object.""" + record_store = ApiDependencies.invoker.services.model_manager.store + metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store + + try: + original_metadata = record_store.get_metadata(key) + if original_metadata: + if changes.default_settings: + original_metadata.default_settings = changes.default_settings + + metadata_store.update_metadata(key, original_metadata) + else: + metadata_store.add_metadata( + key, BaseMetadata(name="", author="", default_settings=changes.default_settings) + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the model metadata: {e}", + ) + + result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) + + return result + + @model_manager_router.get( "/tags", operation_id="list_tags", @@ -451,6 +494,7 @@ async def add_model_record( ) async def install_model( source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"), + inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False), # TODO(MM2): Can we type this? config: Optional[Dict[str, Any]] = Body( description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", @@ -493,6 +537,7 @@ async def install_model( source=source, config=config, access_token=access_token, + inplace=bool(inplace), ) logger.info(f"Started installation of {source}") except UnknownModelException as e: diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3243cc0fd3..4c766e955c 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -181,6 +181,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ) +@invocation_output("gradient_mask_output") +class GradientMaskOutput(BaseInvocationOutput): + """Outputs a denoise mask and an image representing the total gradient of the mask.""" + + denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + expanded_mask_area: ImageField = OutputField( + description="Image representing the total gradient area of the mask. For paste-back purposes." + ) + + @invocation( "create_gradient_mask", title="Create Gradient Mask", @@ -201,38 +211,42 @@ class CreateGradientMaskInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + def invoke(self, context: InvocationContext) -> GradientMaskOutput: mask_image = context.images.get_pil(self.mask.image_name, mode="L") - if self.coherence_mode == "Box Blur": - blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) - else: # Gaussian Blur OR Staged - # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation - blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) + if self.edge_radius > 0: + if self.coherence_mode == "Box Blur": + blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) + else: # Gaussian Blur OR Staged + # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation + blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) - mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) - blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) + blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) - # redistribute blur so that the edges are 0 and blur out to 1 - blur_tensor = (blur_tensor - 0.5) * 2 + # redistribute blur so that the original edges are 0 and blur outwards to 1 + blur_tensor = (blur_tensor - 0.5) * 2 - threshold = 1 - self.minimum_denoise + threshold = 1 - self.minimum_denoise + + if self.coherence_mode == "Staged": + # wherever the blur_tensor is less than fully masked, convert it to threshold + blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor) + else: + # wherever the blur_tensor is above threshold but less than 1, drop it to threshold + blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) - if self.coherence_mode == "Staged": - # wherever the blur_tensor is masked to any degree, convert it to threshold - blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor) else: - # wherever the blur_tensor is above threshold but less than 1, drop it to threshold - blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) - - # multiply original mask to force actually masked regions to 0 - blur_tensor = mask_tensor * blur_tensor + blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1)) - return DenoiseMaskOutput.build( - mask_name=mask_name, - masked_latents_name=None, - gradient=True, + # compute a [0, 1] mask from the blur_tensor + expanded_mask = torch.where((blur_tensor < 1), 0, 1) + expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L") + expanded_image_dto = context.images.save(expanded_mask_image) + + return GradientMaskOutput( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True), + expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name), ) @@ -518,7 +532,7 @@ class DenoiseLatentsInvocation(BaseInvocation): def get_conditioning_data( self, context: InvocationContext, - unet, + unet: UNet2DConditionModel, latent_height: int, latent_width: int, ) -> TextConditioningData: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index f522282fee..b91f961099 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -7,7 +7,6 @@ import time from hashlib import sha256 from pathlib import Path from queue import Empty, Queue -from random import randbytes from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp from typing import Any, Dict, List, Optional, Set, Union @@ -21,6 +20,7 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase +from invokeai.app.util.misc import uuid_string from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, @@ -150,7 +150,7 @@ class ModelInstallService(ModelInstallServiceBase): config = config or {} if not config.get("source"): config["source"] = model_path.resolve().as_posix() - config["key"] = config.get("key", self._create_key()) + config["key"] = config.get("key", uuid_string()) info: AnyModelConfig = self._probe_model(Path(model_path), config) @@ -178,13 +178,14 @@ class ModelInstallService(ModelInstallServiceBase): source: str, config: Optional[Dict[str, Any]] = None, access_token: Optional[str] = None, + inplace: bool = False, ) -> ModelInstallJob: variants = "|".join(ModelRepoVariant.__members__.values()) hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" source_obj: Optional[StringLikeSource] = None if Path(source).exists(): # A local file or directory - source_obj = LocalModelSource(path=Path(source)) + source_obj = LocalModelSource(path=Path(source), inplace=inplace) elif match := re.match(hf_repoid_re, source): source_obj = HFModelSource( repo_id=match.group(1), @@ -526,16 +527,17 @@ class ModelInstallService(ModelInstallServiceBase): setattr(info, key, value) return info - def _create_key(self) -> str: - return sha256(randbytes(100)).hexdigest()[0:32] - def _register( self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None ) -> str: # Note that we may be passed a pre-populated AnyModelConfig object, # in which case the key field should have been populated by the caller (e.g. in `install_path`). - config["key"] = config.get("key", self._create_key()) + config["key"] = config.get("key", uuid_string()) info = info or ModelProbe.probe(model_path, config) + override_key: Optional[str] = config.get("key") if config else None + + assert info.original_hash # always assigned by probe() + info.key = override_key or info.original_hash model_path = model_path.absolute() if model_path.is_relative_to(self.app_config.models_path): diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py b/invokeai/app/services/model_metadata/metadata_store_base.py index e0e4381b09..882575a4bf 100644 --- a/invokeai/app/services/model_metadata/metadata_store_base.py +++ b/invokeai/app/services/model_metadata/metadata_store_base.py @@ -4,9 +4,25 @@ Storage for Model Metadata """ from abc import ABC, abstractmethod -from typing import List, Set, Tuple +from typing import List, Optional, Set, Tuple +from pydantic import Field + +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata +from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings + + +class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"): + """A set of changes to apply to model metadata. + Only limited changes are valid: + - `default_settings`: the user-configured default settings for this model + """ + + default_settings: Optional[ModelDefaultSettings] = Field( + default=None, description="The user-configured default settings for this model" + ) + """The user-configured default settings for this model""" class ModelMetadataStoreBase(ABC): diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py index afe9d2c8c6..4f8170448f 100644 --- a/invokeai/app/services/model_metadata/metadata_store_sql.py +++ b/invokeai/app/services/model_metadata/metadata_store_sql.py @@ -179,44 +179,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase): ) return {x[0] for x in self._cursor.fetchall()} - def _update_tags(self, model_key: str, tags: Set[str]) -> None: + def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None: """Update tags for the model referenced by model_key.""" - # remove previous tags from this model - self._cursor.execute( - """--sql - DELETE FROM model_tags - WHERE model_id=?; - """, - (model_key,), - ) + if tags: + # remove previous tags from this model + self._cursor.execute( + """--sql + DELETE FROM model_tags + WHERE model_id=?; + """, + (model_key,), + ) - for tag in tags: - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO tags ( - tag_text - ) - VALUES (?); - """, - (tag,), - ) - self._cursor.execute( - """--sql - SELECT tag_id - FROM tags - WHERE tag_text = ? - LIMIT 1; - """, - (tag,), - ) - tag_id = self._cursor.fetchone()[0] - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO model_tags ( - model_id, - tag_id - ) - VALUES (?,?); - """, - (model_key, tag_id), - ) + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? + LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tags ( + model_id, + tag_id + ) + VALUES (?,?); + """, + (model_key, tag_id), + ) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index c0b98220c8..a9039e2481 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -200,6 +200,7 @@ class DefaultSessionProcessor(SessionProcessorBase): self._invoker.services.logger.error( f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" ) + self._invoker.services.logger.error(error) # Send error event self._invoker.services.events.emit_invocation_error( diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py b/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py index 2da998a532..be4d5f0140 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py @@ -3,7 +3,6 @@ import json import sqlite3 -from hashlib import sha1 from logging import Logger from pathlib import Path from typing import Optional @@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import ( ModelConfigFactory, ModelType, ) -from invokeai.backend.model_manager.hash import FastModelHash +from invokeai.backend.model_manager.hash import ModelHash ModelsValidator = TypeAdapter(AnyModelConfig) @@ -73,19 +72,27 @@ class MigrateModelYamlToDb1: base_type, model_type, model_name = str(model_key).split("/") try: - hash = FastModelHash.hash(self.config.models_path / stanza.path) + hash = ModelHash().hash(self.config.models_path / stanza.path) except OSError: self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") continue - assert isinstance(model_key, str) - new_key = sha1(model_key.encode("utf-8")).hexdigest() - stanza["base"] = BaseModelType(base_type) stanza["type"] = ModelType(model_type) stanza["name"] = model_name stanza["original_hash"] = hash stanza["current_hash"] = hash + new_key = hash # deterministic key assignment + + # special case for ip adapters, which need the new `image_encoder_model_id` field + if stanza["type"] == ModelType.IPAdapter: + try: + stanza["image_encoder_model_id"] = self._get_image_encoder_model_id( + self.config.models_path / stanza.path + ) + except OSError: + self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.") + continue new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094 @@ -95,7 +102,7 @@ class MigrateModelYamlToDb1: self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}") self._update_model(key, new_config) else: - self.logger.info(f"Adding model {model_name} with key {model_key}") + self.logger.info(f"Adding model {model_name} with key {new_key}") self._add_model(new_key, new_config) except DuplicateModelException: self.logger.warning(f"Model {model_name} is already in the database") @@ -149,3 +156,8 @@ class MigrateModelYamlToDb1: ) except sqlite3.IntegrityError as exc: raise DuplicateModelException(f"{record.name}: model is already in database") from exc + + def _get_image_encoder_model_id(self, model_path: Path) -> str: + with open(model_path / "image_encoder.txt") as f: + encoder = f.read() + return encoder.strip() diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index fb563a8cda..656b591f4a 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -11,56 +11,175 @@ from invokeai.backend.model_managre.model_hash import FastModelHash import hashlib import os from pathlib import Path -from typing import Dict, Union +from typing import Callable, Literal, Optional, Union -from imohash import hashfile +from blake3 import blake3 + +MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth") + +ALGORITHM = Literal[ + "md5", + "sha1", + "sha224", + "sha256", + "sha384", + "sha512", + "blake2b", + "blake2s", + "sha3_224", + "sha3_256", + "sha3_384", + "sha3_512", + "shake_128", + "shake_256", + "blake3", +] -class FastModelHash(object): - """FastModelHash obect provides one public class method, hash().""" +class ModelHash: + """ + Creates a hash of a model using a specified algorithm. - @classmethod - def hash(cls, model_location: Union[str, Path]) -> str: - """ - Return hexdigest string for model located at model_location. + Args: + algorithm: Hashing algorithm to use. Defaults to BLAKE3. + file_filter: A function that takes a file name and returns True if the file should be included in the hash. - :param model_location: Path to the model - """ - model_location = Path(model_location) - if model_location.is_file(): - return cls._hash_file(model_location) - elif model_location.is_dir(): - return cls._hash_dir(model_location) + If the model is a single file, it is hashed directly using the provided algorithm. + + If the model is a directory, each model weights file in the directory is hashed using the provided algorithm. + + Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth + + The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring + that directory hashes are never weaker than the file hashes. + + Usage: + ```py + # BLAKE3 hash + ModelHash().hash("path/to/some/model.safetensors") + # MD5 + ModelHash("md5").hash("path/to/model/dir/") + ``` + """ + + def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None: + if algorithm == "blake3": + self._hash_file = self._blake3 + elif algorithm in hashlib.algorithms_available: + self._hash_file = self._get_hashlib(algorithm) else: - raise OSError(f"Not a valid file or directory: {model_location}") + raise ValueError(f"Algorithm {algorithm} not available") - @classmethod - def _hash_file(cls, model_location: Union[str, Path]) -> str: + self._file_filter = file_filter or self._default_file_filter + + def hash(self, model_path: Union[str, Path]) -> str: """ - Fasthash a single file and return its hexdigest. + Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation. - :param model_location: Path to the model file + If model_path is a directory, the hash is computed by hashing the hashes of all model files in the + directory. The final composite hash is always computed using BLAKE3. + + Args: + model_path: Path to the model + + Returns: + str: Hexdigest of the hash of the model """ - # we return md5 hash of the filehash to make it shorter - # cryptographic security not needed here - return hashlib.md5(hashfile(model_location)).hexdigest() - @classmethod - def _hash_dir(cls, model_location: Union[str, Path]) -> str: - components: Dict[str, str] = {} + model_path = Path(model_path) + if model_path.is_file(): + return self._hash_file(model_path) + elif model_path.is_dir(): + return self._hash_dir(model_path) + else: + raise OSError(f"Not a valid file or directory: {model_path}") - for root, _dirs, files in os.walk(model_location): - for file in files: - # only tally tensor files because diffusers config files change slightly - # depending on how the model was downloaded/converted. - if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): - continue - path = (Path(root) / file).as_posix() - fast_hash = cls._hash_file(path) - components.update({path: fast_hash}) + def _hash_dir(self, dir: Path) -> str: + """Compute the hash for all files in a directory and return a hexdigest. - # hash all the model hashes together, using alphabetic file order - md5 = hashlib.md5() - for _path, fast_hash in sorted(components.items()): - md5.update(fast_hash.encode("utf-8")) - return md5.hexdigest() + Args: + dir: Path to the directory + + Returns: + str: Hexdigest of the hash of the directory + """ + model_component_paths = self._get_file_paths(dir, self._file_filter) + + component_hashes: list[str] = [] + for component in sorted(model_component_paths): + component_hashes.append(self._hash_file(component)) + + # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm + # for the composite hash + composite_hasher = blake3() + for h in component_hashes: + composite_hasher.update(h.encode("utf-8")) + return composite_hasher.hexdigest() + + @staticmethod + def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]: + """Return a list of all model files in the directory. + + Args: + model_path: Path to the model + file_filter: Function that takes a file name and returns True if the file should be included in the list. + + Returns: + List of all model files in the directory + """ + + files: list[Path] = [] + for root, _dirs, _files in os.walk(model_path): + for file in _files: + if file_filter(file): + files.append(Path(root, file)) + return files + + @staticmethod + def _blake3(file_path: Path) -> str: + """Hashes a file using BLAKE3 + + Args: + file_path: Path to the file to hash + + Returns: + Hexdigest of the hash of the file + """ + file_hasher = blake3(max_threads=blake3.AUTO) + file_hasher.update_mmap(file_path) + return file_hasher.hexdigest() + + @staticmethod + def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]: + """Factory function that returns a function to hash a file with the given algorithm. + + Args: + algorithm: Hashing algorithm to use + + Returns: + A function that hashes a file using the given algorithm + """ + + def hashlib_hasher(file_path: Path) -> str: + """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory.""" + hasher = hashlib.new(algorithm) + buffer = bytearray(128 * 1024) + mv = memoryview(buffer) + with open(file_path, "rb", buffering=0) as f: + while n := f.readinto(mv): + hasher.update(mv[:n]) + return hasher.hexdigest() + + return hashlib_hasher + + @staticmethod + def _default_file_filter(file_path: str) -> bool: + """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth + + Args: + file_path: Path to the file + + Returns: + True if the file matches the given extensions, otherwise False + """ + return file_path.endswith(MODEL_FILE_EXTENSIONS) diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 379369f9f5..5f062d0a04 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -25,6 +25,7 @@ from pydantic.networks import AnyHttpUrl from requests.sessions import Session from typing_extensions import Annotated +from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.backend.model_manager import ModelRepoVariant from ..util import select_hf_files @@ -68,12 +69,24 @@ class RemoteModelFile(BaseModel): sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) +class ModelDefaultSettings(BaseModel): + vae: str | None + vae_precision: str | None + scheduler: SCHEDULER_NAME_VALUES | None + steps: int | None + cfg_scale: float | None + cfg_rescale_multiplier: float | None + + class ModelMetadataBase(BaseModel): """Base class for model metadata information.""" name: str = Field(description="model's name") author: str = Field(description="model's author") - tags: Set[str] = Field(description="tags provided by model source") + tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None) + default_settings: Optional[ModelDefaultSettings] = Field( + description="default settings for this model", default=None + ) class BaseMetadata(ModelMetadataBase): diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 11b8f46951..a7250f33d1 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -21,7 +21,7 @@ from .config import ( ModelVariantType, SchedulerPredictionType, ) -from .hash import FastModelHash +from .hash import ModelHash from .util.model_util import lora_token_vector_length, read_checkpoint_meta CkptType = Dict[str, Any] @@ -147,7 +147,7 @@ class ModelProbe(object): if not probe_class: raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") - hash = FastModelHash.hash(model_path) + hash = ModelHash().hash(model_path) probe = probe_class(model_path) fields["path"] = model_path.as_posix() diff --git a/invokeai/frontend/web/public/locales/de.json b/invokeai/frontend/web/public/locales/de.json index 65aa7b2a7a..23211c4e10 100644 --- a/invokeai/frontend/web/public/locales/de.json +++ b/invokeai/frontend/web/public/locales/de.json @@ -134,8 +134,6 @@ "loadMore": "Mehr laden", "noImagesInGallery": "Keine Bilder in der Galerie", "loading": "Lade", - "preparingDownload": "bereite Download vor", - "preparingDownloadFailed": "Problem beim Download vorbereiten", "deleteImage": "Lösche Bild", "copy": "Kopieren", "download": "Runterladen", @@ -967,7 +965,7 @@ "resumeFailed": "Problem beim Fortsetzen des Prozesses", "pruneFailed": "Problem beim leeren der Warteschlange", "pauseTooltip": "Prozess anhalten", - "back": "Hinten", + "back": "Ende", "resumeSucceeded": "Prozess wird fortgesetzt", "resumeTooltip": "Prozess wieder aufnehmen", "time": "Zeit", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 4065b0db86..406a33d9e8 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -78,6 +78,7 @@ "aboutDesc": "Using Invoke for work? Check out:", "aboutHeading": "Own Your Creative Power", "accept": "Accept", + "add": "Add", "advanced": "Advanced", "advancedOptions": "Advanced Options", "ai": "ai", @@ -734,6 +735,8 @@ "customConfig": "Custom Config", "customConfigFileLocation": "Custom Config File Location", "customSaveLocation": "Custom Save Location", + "defaultSettings": "Default Settings", + "defaultSettingsSaved": "Default Settings Saved", "delete": "Delete", "deleteConfig": "Delete Config", "deleteModel": "Delete Model", @@ -768,6 +771,7 @@ "mergedModelName": "Merged Model Name", "mergedModelSaveLocation": "Save Location", "mergeModels": "Merge Models", + "metadata": "Metadata", "model": "Model", "modelAdded": "Model Added", "modelConversionFailed": "Model Conversion Failed", @@ -839,9 +843,12 @@ "statusConverting": "Converting", "syncModels": "Sync Models", "syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.", + "triggerPhrases": "Trigger Phrases", + "typePhraseHere": "Type phrase here", "upcastAttention": "Upcast Attention", "updateModel": "Update Model", "useCustomConfig": "Use Custom Config", + "useDefaultSettings": "Use Default Settings", "v1": "v1", "v2_768": "v2 (768px)", "v2_base": "v2 (512px)", @@ -860,6 +867,7 @@ "models": { "addLora": "Add LoRA", "allLoRAsAdded": "All LoRAs added", + "concepts": "Concepts", "loraAlreadyAdded": "LoRA already added", "esrganModel": "ESRGAN Model", "loading": "loading", diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json index f85cd89721..a4a5aeac90 100644 --- a/invokeai/frontend/web/public/locales/es.json +++ b/invokeai/frontend/web/public/locales/es.json @@ -505,8 +505,6 @@ "seamLowThreshold": "Bajo", "coherencePassHeader": "Parámetros de la coherencia", "compositingSettingsHeader": "Ajustes de la composición", - "coherenceSteps": "Pasos", - "coherenceStrength": "Fuerza", "patchmatchDownScaleSize": "Reducir a escala", "coherenceMode": "Modo" }, diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index 1a55f967f7..b3bd378783 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -114,7 +114,8 @@ "checkpoint": "Checkpoint", "safetensors": "Safetensors", "ai": "ia", - "file": "File" + "file": "File", + "toResolve": "Da risolvere" }, "gallery": { "generations": "Generazioni", @@ -142,8 +143,6 @@ "copy": "Copia", "download": "Scarica", "setCurrentImage": "Imposta come immagine corrente", - "preparingDownload": "Preparazione del download", - "preparingDownloadFailed": "Problema durante la preparazione del download", "downloadSelection": "Scarica gli elementi selezionati", "noImageSelected": "Nessuna immagine selezionata", "deleteSelection": "Elimina la selezione", @@ -609,8 +608,6 @@ "seamLowThreshold": "Basso", "seamHighThreshold": "Alto", "coherencePassHeader": "Passaggio di coerenza", - "coherenceSteps": "Passi", - "coherenceStrength": "Forza", "compositingSettingsHeader": "Impostazioni di composizione", "patchmatchDownScaleSize": "Ridimensiona", "coherenceMode": "Modalità", @@ -1400,19 +1397,6 @@ "Regola la maschera." ] }, - "compositingCoherenceSteps": { - "heading": "Passi", - "paragraphs": [ - "Numero di passi utilizzati nel Passaggio di Coerenza.", - "Simile ai passi di generazione." - ] - }, - "compositingBlur": { - "heading": "Sfocatura", - "paragraphs": [ - "Il raggio di sfocatura della maschera." - ] - }, "compositingCoherenceMode": { "heading": "Modalità", "paragraphs": [ @@ -1431,13 +1415,6 @@ "Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint." ] }, - "compositingStrength": { - "heading": "Forza", - "paragraphs": [ - "Quantità di rumore aggiunta per il Passaggio di Coerenza.", - "Simile alla forza di riduzione del rumore." - ] - }, "paramNegativeConditioning": { "paragraphs": [ "Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.", diff --git a/invokeai/frontend/web/public/locales/ko.json b/invokeai/frontend/web/public/locales/ko.json index 13f09d69ea..4cfb59f781 100644 --- a/invokeai/frontend/web/public/locales/ko.json +++ b/invokeai/frontend/web/public/locales/ko.json @@ -123,8 +123,6 @@ "autoSwitchNewImages": "새로운 이미지로 자동 전환", "loading": "불러오는 중", "unableToLoad": "갤러리를 로드할 수 없음", - "preparingDownload": "다운로드 준비", - "preparingDownloadFailed": "다운로드 준비 중 발생한 문제", "singleColumnLayout": "단일 열 레이아웃", "image": "이미지", "loadMore": "더 불러오기", diff --git a/invokeai/frontend/web/public/locales/nl.json b/invokeai/frontend/web/public/locales/nl.json index c23030bf54..9399dc4898 100644 --- a/invokeai/frontend/web/public/locales/nl.json +++ b/invokeai/frontend/web/public/locales/nl.json @@ -97,8 +97,6 @@ "featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.", "loading": "Bezig met laden", "unableToLoad": "Kan galerij niet laden", - "preparingDownload": "Bezig met voorbereiden van download", - "preparingDownloadFailed": "Fout bij voorbereiden van download", "downloadSelection": "Download selectie", "currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:", "copy": "Kopieer", @@ -535,8 +533,6 @@ "coherencePassHeader": "Coherentiestap", "maskBlur": "Vervaag", "maskBlurMethod": "Vervagingsmethode", - "coherenceSteps": "Stappen", - "coherenceStrength": "Sterkte", "seamHighThreshold": "Hoog", "seamLowThreshold": "Laag", "invoke": { @@ -1139,13 +1135,6 @@ "Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen." ] }, - "compositingCoherenceSteps": { - "heading": "Stappen", - "paragraphs": [ - "Het aantal te gebruiken ontruisingsstappen in de coherentiefase.", - "Gelijk aan de hoofdparameter Stappen." - ] - }, "dynamicPrompts": { "paragraphs": [ "Dynamische prompts vormt een enkele prompt om in vele.", @@ -1160,12 +1149,6 @@ ], "heading": "VAE" }, - "compositingBlur": { - "heading": "Vervaging", - "paragraphs": [ - "De vervagingsstraal van het masker." - ] - }, "paramIterations": { "paragraphs": [ "Het aantal te genereren afbeeldingen.", @@ -1240,13 +1223,6 @@ ], "heading": "Ontruisingssterkte" }, - "compositingStrength": { - "heading": "Sterkte", - "paragraphs": [ - "Ontruisingssterkte voor de coherentiefase.", - "Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding." - ] - }, "paramNegativeConditioning": { "paragraphs": [ "Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.", diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json index 8468554bab..00e64826e7 100644 --- a/invokeai/frontend/web/public/locales/ru.json +++ b/invokeai/frontend/web/public/locales/ru.json @@ -143,8 +143,6 @@ "problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений", "loading": "Загрузка", "unableToLoad": "Невозможно загрузить галерею", - "preparingDownload": "Подготовка к скачиванию", - "preparingDownloadFailed": "Проблема с подготовкой к скачиванию", "image": "изображение", "drop": "перебросить", "problemDeletingImages": "Проблема с удалением изображений", @@ -612,9 +610,7 @@ "maskBlurMethod": "Метод размытия", "seamLowThreshold": "Низкий", "seamHighThreshold": "Высокий", - "coherenceSteps": "Шагов", "coherencePassHeader": "Порог Coherence", - "coherenceStrength": "Сила", "compositingSettingsHeader": "Настройки компоновки", "invoke": { "noNodesInGraph": "Нет узлов в графе", @@ -1321,13 +1317,6 @@ "Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL." ] }, - "compositingCoherenceSteps": { - "heading": "Шаги", - "paragraphs": [ - "Количество шагов снижения шума, используемых при прохождении когерентности.", - "То же, что и основной параметр «Шаги»." - ] - }, "dynamicPrompts": { "paragraphs": [ "Динамические запросы превращают одно приглашение на множество.", @@ -1342,12 +1331,6 @@ ], "heading": "VAE" }, - "compositingBlur": { - "heading": "Размытие", - "paragraphs": [ - "Радиус размытия маски." - ] - }, "paramIterations": { "paragraphs": [ "Количество изображений, которые нужно сгенерировать.", @@ -1422,13 +1405,6 @@ ], "heading": "Шумоподавление" }, - "compositingStrength": { - "heading": "Сила", - "paragraphs": [ - null, - "То же, что параметр «Сила шумоподавления img2img»." - ] - }, "paramNegativeConditioning": { "paragraphs": [ "Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.", diff --git a/invokeai/frontend/web/public/locales/tr.json b/invokeai/frontend/web/public/locales/tr.json index 9fdbae0481..74465c15ed 100644 --- a/invokeai/frontend/web/public/locales/tr.json +++ b/invokeai/frontend/web/public/locales/tr.json @@ -355,7 +355,6 @@ "starImage": "Yıldız Koy", "download": "İndir", "deleteSelection": "Seçileni Sil", - "preparingDownloadFailed": "İndirme Hazırlanırken Sorun", "problemDeletingImages": "Görsel Silmede Sorun", "featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.", "galleryImageResetSize": "Boyutu Resetle", @@ -377,7 +376,6 @@ "setCurrentImage": "Çalışma Görseli Yap", "unableToLoad": "Galeri Yüklenemedi", "downloadSelection": "Seçileni İndir", - "preparingDownload": "İndirmeye Hazırlanıyor", "singleColumnLayout": "Tek Sütun Düzen", "generations": "Çıktılar", "showUploads": "Yüklenenleri Göster", @@ -723,7 +721,6 @@ "clipSkip": "CLIP Atlama", "randomizeSeed": "Rastgele Tohum", "cfgScale": "CFG Ölçeği", - "coherenceStrength": "Etki", "controlNetControlMode": "Yönetim Kipi", "general": "Genel", "img2imgStrength": "Görselden Görsel Ölçüsü", @@ -793,7 +790,6 @@ "cfgRescaleMultiplier": "CFG Rescale Çarpanı", "cfgRescale": "CFG Rescale", "coherencePassHeader": "Uyum Geçişi", - "coherenceSteps": "Adım", "infillMethod": "Doldurma Yöntemi", "maskBlurMethod": "Bulandırma Yöntemi", "steps": "Adım", diff --git a/invokeai/frontend/web/public/locales/zh_CN.json b/invokeai/frontend/web/public/locales/zh_CN.json index 3e4319fef8..673a2c4019 100644 --- a/invokeai/frontend/web/public/locales/zh_CN.json +++ b/invokeai/frontend/web/public/locales/zh_CN.json @@ -136,8 +136,6 @@ "copy": "复制", "download": "下载", "setCurrentImage": "设为当前图像", - "preparingDownload": "准备下载", - "preparingDownloadFailed": "准备下载时出现问题", "downloadSelection": "下载所选内容", "noImageSelected": "无选中的图像", "deleteSelection": "删除所选内容", @@ -616,11 +614,9 @@ "incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。" }, "patchmatchDownScaleSize": "缩小", - "coherenceSteps": "步数", "clipSkip": "CLIP 跳过层", "compositingSettingsHeader": "合成设置", "useCpuNoise": "使用 CPU 噪声", - "coherenceStrength": "强度", "enableNoiseSettings": "启用噪声设置", "coherenceMode": "模式", "cpuNoise": "CPU 噪声", @@ -1402,19 +1398,6 @@ "图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸,SDXL 模型使用等效 1024x1024 的尺寸。" ] }, - "compositingCoherenceSteps": { - "heading": "步数", - "paragraphs": [ - "一致性层中使用的去噪步数。", - "与主参数中的步数相同。" - ] - }, - "compositingBlur": { - "heading": "模糊", - "paragraphs": [ - "遮罩模糊半径。" - ] - }, "noiseUseCPU": { "heading": "使用 CPU 噪声", "paragraphs": [ @@ -1467,13 +1450,6 @@ "第二轮去噪有助于合成内补/外扩图像。" ] }, - "compositingStrength": { - "heading": "强度", - "paragraphs": [ - "一致性层使用的去噪强度。", - "去噪强度与图生图的参数相同。" - ] - }, "paramNegativeConditioning": { "paragraphs": [ "生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index c5d86a127f..8e2715e3fa 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import type { AppDispatch, RootState } from 'app/store/store'; +import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings'; + export const listenerMiddleware = createListenerMiddleware(); export type AppStartListening = TypedStartListening; @@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening); // Dynamic prompts addDynamicPromptsListener(startAppListening); + +addSetDefaultSettingsListener(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts new file mode 100644 index 0000000000..cd4c574be4 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts @@ -0,0 +1,96 @@ +import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { setDefaultSettings } from 'features/parameters/store/actions'; +import { + setCfgRescaleMultiplier, + setCfgScale, + setScheduler, + setSteps, + vaePrecisionChanged, + vaeSelected, +} from 'features/parameters/store/generationSlice'; +import { + isParameterCFGRescaleMultiplier, + isParameterCFGScale, + isParameterPrecision, + isParameterScheduler, + isParameterSteps, + zParameterVAEModel, +} from 'features/parameters/types/parameterSchemas'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { t } from 'i18next'; +import { map } from 'lodash-es'; +import { modelsApi } from 'services/api/endpoints/models'; + +export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => { + startAppListening({ + actionCreator: setDefaultSettings, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + + const currentModel = state.generation.model; + + if (!currentModel) { + return; + } + + const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap(); + + if (!metadata || !metadata.default_settings) { + return; + } + + const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings; + + if (vae) { + // we store this as "default" within default settings + // to distinguish it from no default set + if (vae === 'default') { + dispatch(vaeSelected(null)); + } else { + const { data } = modelsApi.endpoints.getVaeModels.select()(state); + const vaeArray = map(data?.entities); + const validVae = vaeArray.find((model) => model.key === vae); + + const result = zParameterVAEModel.safeParse(validVae); + if (!result.success) { + return; + } + dispatch(vaeSelected(result.data)); + } + } + + if (vae_precision) { + if (isParameterPrecision(vae_precision)) { + dispatch(vaePrecisionChanged(vae_precision)); + } + } + + if (cfg_scale) { + if (isParameterCFGScale(cfg_scale)) { + dispatch(setCfgScale(cfg_scale)); + } + } + + if (cfg_rescale_multiplier) { + if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) { + dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); + } + } + + if (steps) { + if (isParameterSteps(steps)) { + dispatch(setSteps(steps)); + } + } + + if (scheduler) { + if (isParameterScheduler(scheduler)) { + dispatch(setScheduler(scheduler)); + } + } + + dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) }))); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index a2b17b483d..0092d0c99e 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,4 +1,5 @@ import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; +import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import type { InvokeTabName } from 'features/ui/store/tabMap'; import type { O } from 'ts-toolbelt'; @@ -82,6 +83,8 @@ export type AppConfig = { guidance: NumericalParameterConfig; cfgRescaleMultiplier: NumericalParameterConfig; img2imgStrength: NumericalParameterConfig; + scheduler?: ParameterScheduler; + vaePrecision?: ParameterPrecision; // Canvas boundingBoxHeight: NumericalParameterConfig; // initial value comes from model boundingBoxWidth: NumericalParameterConfig; // initial value comes from model diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index e7d40c5eaf..851d098763 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -59,7 +59,7 @@ const LoRASelect = () => { return ( - {t('models.lora')} + {t('models.concepts')} { const { t } = useTranslation(); - if (!status) { + if (!status || !Object.keys(STATUSES).includes(status)) { return <>; } diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx index 9cae8d2984..c19aceda11 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx @@ -8,7 +8,7 @@ export const ModelPane = () => { const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); return ( - {selectedModelKey ? : } + {selectedModelKey ? : } ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings.tsx new file mode 100644 index 0000000000..d45f33e390 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings.tsx @@ -0,0 +1,66 @@ +import { skipToken } from '@reduxjs/toolkit/query'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppSelector } from 'app/store/storeHooks'; +import Loading from 'common/components/Loading/Loading'; +import { selectConfigSlice } from 'features/system/store/configSlice'; +import { isNil } from 'lodash-es'; +import { useMemo } from 'react'; +import { useGetModelMetadataQuery } from 'services/api/endpoints/models'; + +import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm'; + +const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => { + const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd; + + return { + initialSteps: steps.initial, + initialCfg: guidance.initial, + initialScheduler: scheduler, + initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial, + initialVaePrecision: vaePrecision, + }; +}); + +export const DefaultSettings = () => { + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + + const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken); + const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } = + useAppSelector(initialStatesSelector); + + const defaultSettingsDefaults = useMemo(() => { + return { + vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' }, + vaePrecision: { + isEnabled: !isNil(data?.default_settings?.vae_precision), + value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32', + }, + scheduler: { + isEnabled: !isNil(data?.default_settings?.scheduler), + value: data?.default_settings?.scheduler || initialScheduler || 'euler', + }, + steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps }, + cfgScale: { + isEnabled: !isNil(data?.default_settings?.cfg_scale), + value: data?.default_settings?.cfg_scale || initialCfg, + }, + cfgRescaleMultiplier: { + isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier), + value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier, + }, + }; + }, [ + data?.default_settings, + initialSteps, + initialCfg, + initialScheduler, + initialCfgRescaleMultiplier, + initialVaePrecision, + ]); + + if (isLoading) { + return ; + } + + return ; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgRescaleMultiplier.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgRescaleMultiplier.tsx new file mode 100644 index 0000000000..fd88bab662 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgRescaleMultiplier.tsx @@ -0,0 +1,72 @@ +import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +import type { DefaultSettingsFormData } from './DefaultSettingsForm'; + +type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier']; + +export function DefaultCfgRescaleMultiplier(props: UseControllerProps) { + const { field } = useController(props); + + const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin); + const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax); + const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin); + const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax); + const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep); + const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep); + const { t } = useTranslation(); + const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]); + + const onChange = useCallback( + (v: number) => { + const updatedValue = { + ...(field.value as DefaultCfgRescaleMultiplierType), + value: v, + }; + field.onChange(updatedValue); + }, + [field] + ); + + const value = useMemo(() => { + return (field.value as DefaultCfgRescaleMultiplierType).value; + }, [field.value]); + + const isDisabled = useMemo(() => { + return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled; + }, [field.value]); + + return ( + + + {t('parameters.cfgRescaleMultiplier')} + + + + + + + ); +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgScale.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgScale.tsx new file mode 100644 index 0000000000..8e49517eb4 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgScale.tsx @@ -0,0 +1,72 @@ +import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +import type { DefaultSettingsFormData } from './DefaultSettingsForm'; + +type DefaultCfgType = DefaultSettingsFormData['cfgScale']; + +export function DefaultCfgScale(props: UseControllerProps) { + const { field } = useController(props); + + const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin); + const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax); + const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin); + const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax); + const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep); + const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep); + const { t } = useTranslation(); + const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]); + + const onChange = useCallback( + (v: number) => { + const updatedValue = { + ...(field.value as DefaultCfgType), + value: v, + }; + field.onChange(updatedValue); + }, + [field] + ); + + const value = useMemo(() => { + return (field.value as DefaultCfgType).value; + }, [field.value]); + + const isDisabled = useMemo(() => { + return !(field.value as DefaultCfgType).isEnabled; + }, [field.value]); + + return ( + + + {t('parameters.cfgScale')} + + + + + + + ); +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultScheduler.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultScheduler.tsx new file mode 100644 index 0000000000..46b42fd873 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultScheduler.tsx @@ -0,0 +1,50 @@ +import type { ComboboxOnChange } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants'; +import { isParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +import type { DefaultSettingsFormData } from './DefaultSettingsForm'; + +type DefaultSchedulerType = DefaultSettingsFormData['scheduler']; + +export function DefaultScheduler(props: UseControllerProps) { + const { t } = useTranslation(); + const { field } = useController(props); + + const onChange = useCallback( + (v) => { + if (!isParameterScheduler(v?.value)) { + return; + } + const updatedValue = { + ...(field.value as DefaultSchedulerType), + value: v.value, + }; + field.onChange(updatedValue); + }, + [field] + ); + + const value = useMemo( + () => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value), + [field] + ); + + const isDisabled = useMemo(() => { + return !(field.value as DefaultSchedulerType).isEnabled; + }, [field.value]); + + return ( + + + {t('parameters.scheduler')} + + + + ); +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSettingsForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSettingsForm.tsx new file mode 100644 index 0000000000..699e3e3445 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSettingsForm.tsx @@ -0,0 +1,147 @@ +import { Button, Flex, Heading } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { useCallback } from 'react'; +import type { SubmitHandler } from 'react-hook-form'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { IoPencil } from 'react-icons/io5'; +import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models'; + +import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier'; +import { DefaultCfgScale } from './DefaultCfgScale'; +import { DefaultScheduler } from './DefaultScheduler'; +import { DefaultSteps } from './DefaultSteps'; +import { DefaultVae } from './DefaultVae'; +import { DefaultVaePrecision } from './DefaultVaePrecision'; +import { SettingToggle } from './SettingToggle'; + +export interface FormField { + value: T; + isEnabled: boolean; +} + +export type DefaultSettingsFormData = { + vae: FormField; + vaePrecision: FormField; + scheduler: FormField; + steps: FormField; + cfgScale: FormField; + cfgRescaleMultiplier: FormField; +}; + +export const DefaultSettingsForm = ({ + defaultSettingsDefaults, +}: { + defaultSettingsDefaults: DefaultSettingsFormData; +}) => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + + const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation(); + + const { handleSubmit, control, formState } = useForm({ + defaultValues: defaultSettingsDefaults, + }); + + const onSubmit = useCallback>( + (data) => { + if (!selectedModelKey) { + return; + } + + const body = { + vae: data.vae.isEnabled ? data.vae.value : null, + vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null, + cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null, + cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null, + steps: data.steps.isEnabled ? data.steps.value : null, + scheduler: data.scheduler.isEnabled ? data.scheduler.value : null, + }; + + editModelMetadata({ + key: selectedModelKey, + body: { default_settings: body }, + }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('modelManager.defaultSettingsSaved'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + }, + [selectedModelKey, dispatch, editModelMetadata, t] + ); + + return ( + <> + + {t('modelManager.defaultSettings')} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSteps.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSteps.tsx new file mode 100644 index 0000000000..4ccef8fd73 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSteps.tsx @@ -0,0 +1,72 @@ +import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +import type { DefaultSettingsFormData } from './DefaultSettingsForm'; + +type DefaultSteps = DefaultSettingsFormData['steps']; + +export function DefaultSteps(props: UseControllerProps) { + const { field } = useController(props); + + const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin); + const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax); + const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin); + const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax); + const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep); + const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep); + const { t } = useTranslation(); + const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]); + + const onChange = useCallback( + (v: number) => { + const updatedValue = { + ...(field.value as DefaultSteps), + value: v, + }; + field.onChange(updatedValue); + }, + [field] + ); + + const value = useMemo(() => { + return (field.value as DefaultSteps).value; + }, [field.value]); + + const isDisabled = useMemo(() => { + return !(field.value as DefaultSteps).isEnabled; + }, [field.value]); + + return ( + + + {t('parameters.steps')} + + + + + + + ); +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVae.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVae.tsx new file mode 100644 index 0000000000..b32f17dca1 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVae.tsx @@ -0,0 +1,65 @@ +import type { ComboboxOnChange } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppSelector } from 'app/store/storeHooks'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { map } from 'lodash-es'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models'; + +import type { DefaultSettingsFormData } from './DefaultSettingsForm'; + +type DefaultVaeType = DefaultSettingsFormData['vae']; + +export function DefaultVae(props: UseControllerProps) { + const { t } = useTranslation(); + const { field } = useController(props); + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken); + + const { compatibleOptions } = useGetVaeModelsQuery(undefined, { + selectFromResult: ({ data }) => { + const modelArray = map(data?.entities); + const compatibleOptions = modelArray + .filter((vae) => vae.base === modelData?.base) + .map((vae) => ({ label: vae.name, value: vae.key })); + + const defaultOption = { label: 'Default VAE', value: 'default' }; + + return { compatibleOptions: [defaultOption, ...compatibleOptions] }; + }, + }); + + const onChange = useCallback( + (v) => { + const newValue = !v?.value ? 'default' : v.value; + + const updatedValue = { + ...(field.value as DefaultVaeType), + value: newValue, + }; + field.onChange(updatedValue); + }, + [field] + ); + + const value = useMemo(() => { + return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value); + }, [compatibleOptions, field.value]); + + const isDisabled = useMemo(() => { + return !(field.value as DefaultVaeType).isEnabled; + }, [field.value]); + + return ( + + + {t('modelManager.vae')} + + + + ); +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVaePrecision.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVaePrecision.tsx new file mode 100644 index 0000000000..240342b446 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVaePrecision.tsx @@ -0,0 +1,51 @@ +import type { ComboboxOnChange } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { isParameterPrecision } from 'features/parameters/types/parameterSchemas'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +import type { DefaultSettingsFormData } from './DefaultSettingsForm'; + +const options = [ + { label: 'FP16', value: 'fp16' }, + { label: 'FP32', value: 'fp32' }, +]; + +type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision']; + +export function DefaultVaePrecision(props: UseControllerProps) { + const { t } = useTranslation(); + const { field } = useController(props); + + const onChange = useCallback( + (v) => { + if (!isParameterPrecision(v?.value)) { + return; + } + const updatedValue = { + ...(field.value as DefaultVaePrecisionType), + value: v.value, + }; + field.onChange(updatedValue); + }, + [field] + ); + + const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]); + + const isDisabled = useMemo(() => { + return !(field.value as DefaultVaePrecisionType).isEnabled; + }, [field.value]); + + return ( + + + {t('modelManager.vaePrecision')} + + + + ); +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle.tsx new file mode 100644 index 0000000000..bcea4959a8 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle.tsx @@ -0,0 +1,28 @@ +import { Switch } from '@invoke-ai/ui-library'; +import type { ChangeEvent } from 'react'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; + +import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm'; + +export function SettingToggle(props: UseControllerProps) { + const { field } = useController(props); + + const value = useMemo(() => { + return !!(field.value as FormField).isEnabled; + }, [field.value]); + + const onChange = useCallback( + (e: ChangeEvent) => { + const updatedValue: FormField = { + ...(field.value as FormField), + isEnabled: e.target.checked, + }; + field.onChange(updatedValue); + }, + [field] + ); + + return ; +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Metadata/ModelMetadata.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Metadata/ModelMetadata.tsx new file mode 100644 index 0000000000..7dc3c0bf62 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Metadata/ModelMetadata.tsx @@ -0,0 +1,18 @@ +import { Flex } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppSelector } from 'app/store/storeHooks'; +import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; +import { useGetModelMetadataQuery } from 'services/api/endpoints/models'; + +export const ModelMetadata = () => { + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken); + + return ( + <> + + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx index 6db804cccf..96e2629443 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx @@ -1,9 +1,58 @@ +import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; +import { useTranslation } from 'react-i18next'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; +import { ModelMetadata } from './Metadata/ModelMetadata'; +import { ModelAttrView } from './ModelAttrView'; import { ModelEdit } from './ModelEdit'; import { ModelView } from './ModelView'; export const Model = () => { + const { t } = useTranslation(); const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode); - return selectedModelMode === 'view' ? : ; + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken); + + if (isLoading) { + return {t('common.loading')}; + } + + if (!data) { + return {t('common.somethingWentWrong')}; + } + + return ( + <> + + + {data.name} + + + {data.source && ( + + {t('modelManager.source')}: {data?.source} + + )} + + + + + + + + {t('modelManager.settings')} + {t('modelManager.metadata')} + + + + {selectedModelMode === 'view' ? : } + + + + + + + ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx index 2acbfe8b3e..0b25e5fdc7 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx @@ -1,12 +1,11 @@ -import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library'; +import { Box, Button, Flex, Text } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { IoPencil } from 'react-icons/io5'; -import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; import type { CheckpointModelConfig, ControlNetModelConfig, @@ -18,6 +17,7 @@ import type { VAEModelConfig, } from 'services/api/types'; +import { DefaultSettings } from './DefaultSettings'; import { ModelAttrView } from './ModelAttrView'; import { ModelConvert } from './ModelConvert'; @@ -26,7 +26,6 @@ export const ModelView = () => { const dispatch = useAppDispatch(); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken); - const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken); const modelData = useMemo(() => { if (!data) { @@ -73,85 +72,56 @@ export const ModelView = () => { return {t('common.somethingWentWrong')}; } return ( - - - - - {modelData.name} - - - {modelData.source && ( - - {t('modelManager.source')}: {modelData.source} - - )} - - + + + + {modelData.type === 'main' && modelData.format === 'checkpoint' && } - - - - - - - - {t('modelManager.modelSettings')} - - - - - - - - - - - - {modelData.type === 'main' && ( - <> - - {modelData.format === 'diffusers' && ( - - )} - {modelData.format === 'checkpoint' && ( - - )} - - - - - - - - - - - - - )} - {modelData.type === 'ip_adapter' && ( + + + + + + + + + + {modelData.type === 'main' && ( + <> - - - )} - - - + {modelData.format === 'diffusers' && ( + + )} + {modelData.format === 'checkpoint' && ( + + )} - {metadata && ( - <> - - {t('modelManager.modelMetadata')} - - - - - - )} + + + + + + + + + + + + )} + {modelData.type === 'ip_adapter' && ( + + + + )} + + + + + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts index 00bad63c3b..2672cf5be3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts @@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts index 75f9a15f48..a9707e50f8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts @@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts index fc60805e85..9f4e75de48 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts @@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts index 44950ff40a..6c5a31926a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts @@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN, diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/NavigateToModelManagerButton.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/NavigateToModelManagerButton.tsx new file mode 100644 index 0000000000..733fb83826 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/NavigateToModelManagerButton.tsx @@ -0,0 +1,36 @@ +import type { IconButtonProps } from '@invoke-ai/ui-library'; +import { IconButton } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiGearSixBold } from 'react-icons/pi'; + +export const NavigateToModelManagerButton = memo((props: Omit) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const disabledTabs = useAppSelector((s) => s.config.disabledTabs); + const shouldShowButton = useMemo(() => !disabledTabs.includes('modelManager'), [disabledTabs]); + + const handleClick = useCallback(() => { + dispatch(setActiveTab('modelManager')); + }, [dispatch]); + + if (!shouldShowButton) { + return null; + } + + return ( + } + tooltip={t('modelManager.modelManager')} + aria-label={t('modelManager.modelManager')} + onClick={handleClick} + size="sm" + variant="ghost" + {...props} + /> + ); +}); + +NavigateToModelManagerButton.displayName = 'NavigateToModelManagerButton'; diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/UseDefaultSettingsButton.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/UseDefaultSettingsButton.tsx new file mode 100644 index 0000000000..7b322a3227 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/UseDefaultSettingsButton.tsx @@ -0,0 +1,28 @@ +import { IconButton } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { setDefaultSettings } from 'features/parameters/store/actions'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { RiSparklingFill } from 'react-icons/ri'; + +export const UseDefaultSettingsButton = () => { + const model = useAppSelector((s) => s.generation.model); + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const handleClickDefaultSettings = useCallback(() => { + dispatch(setDefaultSettings()); + }, [dispatch]); + + return ( + } + tooltip={t('modelManager.useDefaultSettings')} + aria-label={t('modelManager.useDefaultSettings')} + isDisabled={!model} + onClick={handleClickDefaultSettings} + size="sm" + variant="ghost" + /> + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts index f7bf127c05..3b43129720 100644 --- a/invokeai/frontend/web/src/features/parameters/store/actions.ts +++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts @@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types'; export const initialImageSelected = createAction('generation/initialImageSelected'); export const modelSelected = createAction('generation/modelSelected'); + +export const setDefaultSettings = createAction('generation/setDefaultSettings'); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 49ca507439..0f36d8b477 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -230,6 +230,12 @@ export const generationSlice = createSlice({ state.height = optimalDimension; } } + if (action.payload.sd?.scheduler) { + state.scheduler = action.payload.sd.scheduler; + } + if (action.payload.sd?.vaePrecision) { + state.vaePrecision = action.payload.sd.vaePrecision; + } }); // TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 26a55b7c70..ab2d5abed6 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -1,15 +1,5 @@ import type { FormLabelProps } from '@invoke-ai/ui-library'; -import { - Expander, - Flex, - FormControlGroup, - StandaloneAccordion, - Tab, - TabList, - TabPanel, - TabPanels, - Tabs, -} from '@invoke-ai/ui-library'; +import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library'; import { EMPTY_ARRAY } from 'app/store/constants'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; @@ -20,7 +10,9 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; import ParamScheduler from 'features/parameters/components/Core/ParamScheduler'; import ParamSteps from 'features/parameters/components/Core/ParamSteps'; +import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton'; import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect'; +import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { filter } from 'lodash-es'; @@ -39,11 +31,11 @@ export const GenerationSettingsAccordion = memo(() => { () => createMemoizedSelector(selectLoraSlice, (lora) => { const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; - const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY; + const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY; const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY; return { loraTabBadges, accordionBadges }; }), - [modelConfig] + [modelConfig, t] ); const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges); const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({ @@ -58,39 +50,35 @@ export const GenerationSettingsAccordion = memo(() => { return ( - - - {t('accordions.generation.modelTab')} - {t('accordions.generation.conceptsTab')} - - - - - + + + + + + + - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + ); }); diff --git a/invokeai/frontend/web/src/features/system/store/configSlice.ts b/invokeai/frontend/web/src/features/system/store/configSlice.ts index 4e1b734a66..76280df1ce 100644 --- a/invokeai/frontend/web/src/features/system/store/configSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/configSlice.ts @@ -41,6 +41,8 @@ const initialConfigState: AppConfig = { boundingBoxHeight: { ...baseDimensionConfig }, scaledBoundingBoxWidth: { ...baseDimensionConfig }, scaledBoundingBoxHeight: { ...baseDimensionConfig }, + scheduler: 'euler', + vaePrecision: 'fp32', steps: { initial: 30, sliderMin: 1, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 04c65b59f6..dac6594255 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -24,7 +24,15 @@ export type UpdateModelArg = { body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json']; }; +type UpdateModelMetadataArg = { + key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key']; + body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json']; +}; + type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; +type UpdateModelMetadataResponse = + paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json']; + type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; type GetModelMetadataResponse = @@ -172,6 +180,16 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), + updateModelMetadata: build.mutation({ + query: ({ key, body }) => { + return { + url: buildModelsUrl(`i/${key}/metadata`), + method: 'PATCH', + body: body, + }; + }, + invalidatesTags: ['Model'], + }), installModel: build.mutation({ query: ({ source, config, access_token }) => { return { @@ -351,6 +369,7 @@ export const { useGetModelMetadataQuery, useDeleteModelImportMutation, usePruneModelImportsMutation, + useUpdateModelMetadataMutation, } = modelsApi; const upsertModelConfigs = ( diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 12227d1ae9..560feb93ba 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -60,6 +60,11 @@ export type paths = { * @description Get a model metadata object. */ get: operations["get_model_metadata"]; + /** + * Update Model Metadata + * @description Updates or creates a model metadata object. + */ + patch: operations["update_model_metadata"]; }; "/api/v2/models/tags": { /** @@ -757,7 +762,14 @@ export type components = { * Tags * @description tags provided by model source */ - tags: string[]; + tags?: string[] | null; + /** + * Trigger Phrases + * @description trigger phrases for this model + */ + trigger_phrases?: string[] | null; + /** @description default settings for this model */ + default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Type * @default basemetadata @@ -1806,7 +1818,14 @@ export type components = { * Tags * @description tags provided by model source */ - tags: string[]; + tags?: string[] | null; + /** + * Trigger Phrases + * @description trigger phrases for this model + */ + trigger_phrases?: string[] | null; + /** @description default settings for this model */ + default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Files * @description model files and their sizes @@ -4264,7 +4283,7 @@ export type components = { * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["ColorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["MergeTilesToImageInvocation"]; + [key: string]: components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"]; }; /** * Edges @@ -4301,7 +4320,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["String2Output"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["FloatCollectionOutput"]; + [key: string]: components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["LatentsCollectionOutput"]; }; /** * Errors @@ -4424,7 +4443,14 @@ export type components = { * Tags * @description tags provided by model source */ - tags: string[]; + tags?: string[] | null; + /** + * Trigger Phrases + * @description trigger phrases for this model + */ + trigger_phrases?: string[] | null; + /** @description default settings for this model */ + default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Files * @description model files and their sizes @@ -7430,6 +7456,21 @@ export type components = { */ type: "mlsd_image_processor"; }; + /** ModelDefaultSettings */ + ModelDefaultSettings: { + /** Vae */ + vae: string | null; + /** Vae Precision */ + vae_precision: string | null; + /** Scheduler */ + scheduler: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm") | null; + /** Steps */ + steps: number | null; + /** Cfg Scale */ + cfg_scale: number | null; + /** Cfg Rescale Multiplier */ + cfg_rescale_multiplier: number | null; + }; /** * ModelFormat * @description Storage format of model. @@ -7556,6 +7597,24 @@ export type components = { */ unet: components["schemas"]["UNetField"]; }; + /** + * ModelMetadataChanges + * @description A set of changes to apply to model metadata. + * + * Only limited changes are valid: + * - `trigger_phrases`: the list of trigger phrases for this model + * - `default_settings`: the user-configured default settings for this model + */ + ModelMetadataChanges: { + /** + * Trigger Phrases + * @description The model's list of trigger phrases + */ + trigger_phrases?: string[] | null; + /** @description The user-configured default settings for this model */ + default_settings?: components["schemas"]["ModelDefaultSettings"] | null; + [key: string]: unknown; + }; /** * ModelRecordOrderBy * @description The order in which to return model summaries. @@ -11203,6 +11262,47 @@ export type operations = { }; }; }; + /** + * Update Model Metadata + * @description Updates or creates a model metadata object. + */ + update_model_metadata: { + parameters: { + path: { + /** @description Key of the model repo metadata to fetch. */ + key: string; + }; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ModelMetadataChanges"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null; + }; + }; + /** @description The model metadata was updated successfully */ + 201: { + content: { + "application/json": unknown; + }; + }; + /** @description Bad request */ + 400: { + content: never; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** * List Tags * @description Get a unique set of all the model tags. diff --git a/pyproject.toml b/pyproject.toml index 26db5a63c7..d9f9634739 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,12 +51,12 @@ dependencies = [ "torchmetrics==0.11.4", "torchsde==0.2.6", "torchvision==0.16.2", - "transformers==4.37.2", + "transformers==4.38.2", # Core application dependencies, pinned for reproducible builds. "fastapi-events==0.10.1", "fastapi==0.109.2", - "huggingface-hub==0.20.3", + "huggingface-hub==0.21.3", "pydantic-settings==2.1.0", "pydantic==2.6.1", "python-socketio==5.11.1", @@ -64,6 +64,7 @@ dependencies = [ # Auxiliary dependencies, pinned only if necessary. "albumentations", + "blake3", "click", "datasets", "Deprecated", @@ -72,7 +73,6 @@ dependencies = [ "easing-functions", "einops", "facexlib", - "imohash", "matplotlib", # needed for plotting of Penner easing functions "npyscreen", "omegaconf", diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 00c463745c..4e146b44f9 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -3,6 +3,7 @@ Test the model installer """ import platform +import uuid from pathlib import Path import pytest @@ -30,9 +31,8 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa matches = store.search_by_attr(model_name="test_embedding") assert len(matches) == 0 key = mm2_installer.register_path(embedding_file) - assert key is not None - assert key != "" - assert len(key) == 32 + # Not raising here is sufficient - key should be UUIDv4 + uuid.UUID(key, version=4) def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: diff --git a/tests/test_model_hash.py b/tests/test_model_hash.py new file mode 100644 index 0000000000..641a150034 --- /dev/null +++ b/tests/test_model_hash.py @@ -0,0 +1,96 @@ +# pyright:reportPrivateUsage=false + +from pathlib import Path +from typing import Iterable + +import pytest +from blake3 import blake3 + +from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash + +test_cases: list[tuple[ALGORITHM, str]] = [ + ("md5", "a0cd925fc063f98dbf029eee315060c3"), + ("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"), + ("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"), + ( + "sha512", + "c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6", + ), + ("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"), +] + + +@pytest.mark.parametrize("algorithm,expected_hash", test_cases) +def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str): + file = Path(tmp_path / "test") + file.write_text("model data") + md5 = ModelHash(algorithm).hash(file) + assert md5 == expected_hash + + +@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"]) +def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM): + model_hash = ModelHash(algorithm) + files = [Path(tmp_path, f"{i}.bin") for i in range(5)] + + for f in files: + f.write_text("data") + + md5 = model_hash.hash(tmp_path) + + # Manual implementation of composite hash - always uses BLAKE3 + composite_hasher = blake3() + for f in files: + h = model_hash.hash(f) + composite_hasher.update(h.encode("utf-8")) + + assert md5 == composite_hasher.hexdigest() + + +def test_model_hash_raises_error_on_invalid_algorithm(): + with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"): + ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType] + + +def paths_to_str_set(paths: Iterable[Path]) -> set[str]: + return {str(p) for p in paths} + + +def test_model_hash_filters_out_non_model_files(tmp_path: Path): + model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)} + + for i, f in enumerate(model_files): + f.write_text(f"data{i}") + + assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set( + model_files + ) + + # Add file that should be ignored - hash should not change + file = tmp_path / "test.icecream" + file.write_text("data") + + assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set( + model_files + ) + + # Add file that should not be ignored - hash should change + file = tmp_path / "test.bin" + file.write_text("more data") + model_files.add(file) + + assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set( + model_files + ) + + +def test_model_hash_uses_custom_filter(tmp_path: Path): + model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]} + + for i, f in enumerate(model_files): + f.write_text(f"data{i}") + + def file_filter(file_path: str) -> bool: + return file_path.endswith(".pickme") + + assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}