Compare commits

..

1 Commits

Author SHA1 Message Date
e2684b45af Add cProfile for profiling graph execution. 2024-01-12 10:58:03 -05:00
625 changed files with 12035 additions and 13205 deletions

59
.github/pr_labels.yml vendored
View File

@ -1,59 +0,0 @@
Root:
- changed-files:
- any-glob-to-any-file: '*'
PythonDeps:
- changed-files:
- any-glob-to-any-file: 'pyproject.toml'
Python:
- changed-files:
- all-globs-to-any-file:
- 'invokeai/**'
- '!invokeai/frontend/web/**'
PythonTests:
- changed-files:
- any-glob-to-any-file: 'tests/**'
CICD:
- changed-files:
- any-glob-to-any-file: .github/**
Docker:
- changed-files:
- any-glob-to-any-file: docker/**
Installer:
- changed-files:
- any-glob-to-any-file: installer/**
Documentation:
- changed-files:
- any-glob-to-any-file: docs/**
Invocations:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/invocations/**'
Backend:
- changed-files:
- any-glob-to-any-file: 'invokeai/backend/**'
Api:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/api/**'
Services:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/services/**'
FrontendDeps:
- changed-files:
- any-glob-to-any-file:
- '**/*/package.json'
- '**/*/pnpm-lock.yaml'
Frontend:
- changed-files:
- any-glob-to-any-file: 'invokeai/frontend/web/**'

View File

@ -1,16 +0,0 @@
name: "Pull Request Labeler"
on:
- pull_request_target
jobs:
labeler:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: actions/labeler@v5
with:
configuration-path: .github/pr_labels.yml

View File

@ -58,7 +58,7 @@ jobs:
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
uses: tj-actions/changed-files@v37
with:
files_yaml: |
python:

View File

@ -15,13 +15,8 @@ model. These are the:
their metadata, and `ModelRecordServiceBase` to store that
information. It is also responsible for managing the InvokeAI
`models` directory and its contents.
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
are able to retrieve metadata from online model repositories,
transform them into Pydantic models, and cache them to the InvokeAI
SQL database.
* _DownloadQueueServiceBase_
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
queue has special methods for downloading repo_id folders from
@ -35,13 +30,13 @@ model. These are the:
## Location of the Code
The four main services can be found in
All four of these services can be found in
`invokeai/app/services` in the following directories:
* `invokeai/app/services/model_records/`
* `invokeai/app/services/model_install/`
* `invokeai/app/services/downloads/`
* `invokeai/app/services/model_loader/` (**under development**)
* `invokeai/app/services/downloads/`(**under development**)
Code related to the FastAPI web API can be found in
`invokeai/app/api/routers/model_records.py`.
@ -407,18 +402,15 @@ functionality:
the download, installation and registration process.
- Downloading a model from an arbitrary URL and installing it in
`models_dir`.
`models_dir` (_implementation pending_).
- Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link
paste in a model page's URL or download link (_implementation pending_).
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16.
- Saving tags and other metadata about the model into the invokeai database
when fetching from a repo that provides that type of information,
(currently only Civitai and HuggingFace).
variants such as fp16. (_implementation pending_)
### Initializing the installer
@ -434,24 +426,16 @@ following initialization pattern:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()
installer = ModelInstallService(app_config=config,
record_store=record_store,
download_queue=queue
)
installer.start()
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(config, store)
```
The full form of `ModelInstallService()` takes the following
@ -459,12 +443,9 @@ required parameters:
| **Argument** | **Type** | **Description** |
|------------------|------------------------------|------------------------------|
| `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
| `config` | InvokeAIAppConfig | InvokeAI app configuration object |
| `record_store` | ModelRecordServiceBase | Config record storage database |
| `download_queue` | DownloadQueueServiceBase | Download queue object |
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
| `event_bus` | EventServiceBase | Optional event bus to send download/install progress events to |
Once initialized, the installer will provide the following methods:
@ -493,14 +474,14 @@ source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', ac
for source in [source1, source2, source3, source4, source5, source6, source7]:
install_job = installer.install_model(source)
source2job = installer.wait_for_installs(timeout=120)
source2job = installer.wait_for_installs()
for source in sources:
job = source2job[source]
if job.complete:
if job.status == "completed":
model_config = job.config_out
model_key = model_config.key
print(f"{source} installed as {model_key}")
elif job.errored:
elif job.status == "error":
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```
@ -534,117 +515,43 @@ The full list of arguments to `import_model()` is as follows:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | ModelSource | None | The source of the model, Path, URL or repo_id |
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
| `inplace` | bool | True | Leave a local model in its current location |
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
| `access_token` | str | None | Provide authorization information needed to download |
The next few sections describe the various types of ModelSource that
can be passed to `import_model()`.
The `inplace` field controls how local model Paths are handled. If
True (the default), then the model is simply registered in its current
location by the installer's `ModelConfigRecordService`. Otherwise, a
copy of the model put into the location specified by the `models_dir`
application configuration parameter.
The `variant` field is used for HuggingFace repo_ids only. If
provided, the repo_id download handler will look for and download
tensors files that follow the convention for the selected variant:
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"
In the special case of the "fp16" variant, the installer will select
the 32-bit version of the files if the 16-bit version is unavailable.
`subfolder` is used for HuggingFace repo_ids only. If provided, the
model will be downloaded from the designated subfolder rather than the
top-level repository folder. If a subfolder is attached to the repo_id
using the format `repo_owner/repo_name:subfolder`, then the subfolder
specified by the repo_id will override the subfolder argument.
`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for
details.
#### LocalModelSource
This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `path` | str | Path | None | Path to the model file or directory |
| `inplace` | bool | False | If set, the model file(s) will be left in their location; otherwise they will be copied into the InvokeAI root's `models` directory |
#### URLModelSource
This is used for a single-file model that is accessible via a URL. The
fields are:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `url` | AnyHttpUrl | None | The URL for the model file. |
| `access_token` | str | None | An access token needed to gain access to this file. |
The `AnyHttpUrl` class can be imported from `pydantic.networks`.
Ordinarily, no metadata is retrieved from these sources. However,
there is special-case code in the installer that looks for HuggingFace
and Civitai URLs and fetches the corresponding model metadata from
the corresponding repo.
#### CivitaiModelSource
This is used for a model that is hosted by the Civitai web site.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `version_id` | int | None | The ID of the particular version of the desired model. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
Civitai has two model IDs, both of which are integers. The `model_id`
corresponds to a collection of model versions that may different in
arbitrary ways, such as derivation from different checkpoint training
steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
`version_id` points to a specific version. Please use the latter.
Some Civitai models require an access token to download. These can be
generated from the Civitai profile page of a logged-in
account. Somewhat annoyingly, if you fail to provide the access token
when downloading a model that needs it, Civitai generates a redirect
to a login page rather than a 403 Forbidden error. The installer
attempts to catch this event and issue an informative error
message. Otherwise you will get an "unrecognized model suffix" error
when the model prober tries to identify the type of the HTML login
page.
#### HFModelSource
HuggingFace has the most complicated `ModelSource` structure:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `repo_id` | str | None | The ID of the desired model. |
| `variant` | ModelRepoVariant | ModelRepoVariant('fp16') | The desired variant. |
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
The `variant` is one of the various diffusers formats that HuggingFace
supports and is used to pick out from the hodgepodge of files that in
a typical HuggingFace repository the particular components needed for
a complete diffusers model. `ModelRepoVariant` is an enum that can be
imported from `invokeai.backend.model_manager` and has the following
values:
| **Name** | **String Value** |
|----------------------------|---------------------------|
| ModelRepoVariant.DEFAULT | "default" |
| ModelRepoVariant.FP16 | "fp16" |
| ModelRepoVariant.FP32 | "fp32" |
| ModelRepoVariant.ONNX | "onnx" |
| ModelRepoVariant.OPENVINO | "openvino" |
| ModelRepoVariant.FLAX | "flax" |
You can also pass the string forms to `variant` directly. Note that
InvokeAI may not be able to load and run all variants. At the current
time, specifying `ModelRepoVariant.DEFAULT` will retrieve model files
that are unqualified, e.g. `pytorch_model.safetensors` rather than
`pytorch_model.fp16.safetensors`. These are usually the 32-bit
safetensors forms of the model.
If `subfolder` is specified, then the requested model resides in a
subfolder of the main model repository. This is typically used to
fetch and install VAEs.
Some models require you to be registered with HuggingFace and logged
in. To download these files, you must provide an
`access_token`. Internally, if no access token is provided, then
`HfFolder.get_token()` will be called to fill it in with the cached
one.
`access_token` is passed to the download queue and used to access
repositories that require it.
#### Monitoring the install job process
@ -656,8 +563,7 @@ The `ModelInstallJob` class has the following structure:
| **Attribute** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `id` | `int` | Integer ID for this job |
| `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
| `status` | `InstallStatus` | An enum of ["waiting", "running", "completed" and "error" |
| `config_in` | `dict` | Overriding configuration values provided by the caller |
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
@ -672,70 +578,30 @@ broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and
the following event names:
##### `model_install_downloading`
- `model_install_started`
For remote models only, `model_install_downloading` events will be issued at regular
intervals as the download progresses. The event's payload contains the
following keys:
The payload will contain the keys `timestamp` and `source`. The latter
indicates the requested model source for installation.
| **Key** | **Type** | **Description** |
|----------------|-----------|------------------|
| `source` | str | String representation of the requested source |
| `local_path` | str | String representation of the path to the downloading model (usually a temporary directory) |
| `bytes` | int | How many bytes downloaded so far |
| `total_bytes` | int | Total size of all the files that make up the model |
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
- `model_install_progress`
Emitted at regular intervals when downloading a remote model, the
payload will contain the keys `timestamp`, `source`, `current_bytes`
and `total_bytes`. These events are _not_ emitted when a local model
already on the filesystem is imported.
The parts is a list of dictionaries that give information on each of
the components pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.
- `model_install_completed`
Note that downloading events will not be issued for local models, and
that downloading events occur *before* the running event.
Issued once at the end of a successful installation. The payload will
contain the keys `timestamp`, `source` and `key`, where `key` is the
ID under which the model has been registered.
##### `model_install_running`
`model_install_running` is issued when all the required downloads have completed (if applicable) and the
model probing, copying and registration process has now started.
The payload will contain the key `source`.
##### `model_install_completed`
`model_install_completed` is issued once at the end of a successful
installation. The payload will contain the keys `source`,
`total_bytes` and `key`, where `key` is the ID under which the model
has been registered.
##### `model_install_error`
`model_install_error` is emitted if the installation process fails for
some reason. The payload will contain the keys `source`, `error_type`
and `error`. `error_type` is a short message indicating the nature of
the error, and `error` is the long traceback to help debug the
problem.
##### `model_install_cancelled`
`model_install_cancelled` is issued if the model installation is
cancelled, or if one or more of its files' downloads are
cancelled. The payload will contain `source`.
##### Following the model status
You may poll the `ModelInstallJob` object returned by `import_model()`
to ascertain the state of the install. The job status can be read from
the job's `status` attribute, an `InstallStatus` enum which has the
enumerated values `WAITING`, `DOWNLOADING`, `RUNNING`, `COMPLETED`,
`ERROR` and `CANCELLED`.
For convenience, install jobs also provided the following boolean
properties: `waiting`, `downloading`, `running`, `complete`, `errored`
and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.
- `model_install_error`
Emitted if the installation process fails for some reason. The payload
will contain the keys `timestamp`, `source`, `error_type` and
`error`. `error_type` is a short message indicating the nature of the
error, and `error` is the long traceback to help debug the problem.
#### Model confguration and probing
@ -755,9 +621,17 @@ overriding values for any of the model's configuration
attributes. Here is an example of setting the
`SchedulerPredictionType` and `name` for an sd-2 model:
This is typically used to set
the model's name and description, but can also be used to overcome
cases in which automatic probing is unable to (correctly) determine
the model's attribute. The most common situation is the
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
example of how it works:
```
install_job = installer.import_model(
source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1',variant='fp32'),
source='stabilityai/stable-diffusion-2-1',
variant='fp16',
config=dict(
prediction_type=SchedulerPredictionType('v_prediction')
name='stable diffusion 2 base model',
@ -769,38 +643,29 @@ install_job = installer.import_model(
This section describes additional methods provided by the installer class.
#### jobs = installer.wait_for_installs([timeout])
#### jobs = installer.wait_for_installs()
Block until all pending installs are completed or errored and then
returns a list of completed jobs. The optional `timeout` argument will
return from the call if jobs aren't completed in the specified
time. An argument of 0 (the default) will block indefinitely.
returns a list of completed jobs.
#### jobs = installer.list_jobs()
#### jobs = installer.list_jobs([source])
Return a list of all active and complete `ModelInstallJobs`.
Return a list of all active and complete `ModelInstallJobs`. An
optional `source` argument allows you to filter the returned list by a
model source string pattern using a partial string match.
#### jobs = installer.get_job_by_source(source)
#### jobs = installer.get_job(source)
Return a list of `ModelInstallJob` corresponding to the indicated
model source.
#### jobs = installer.get_job_by_id(id)
Return a list of `ModelInstallJob` corresponding to the indicated
model id.
#### jobs = installer.cancel_job(job)
Cancel the indicated job.
#### installer.prune_jobs
Remove jobs that are in a terminal state (i.e. complete, errored or
cancelled) from the job list returned by `list_jobs()` and
`get_job()`.
Remove non-pending jobs (completed or errored) from the job list
returned by `list_jobs()` and `get_job()`.
#### installer.app_config, installer.record_store, installer.event_bus
#### installer.app_config, installer.record_store,
installer.event_bus
Properties that provide access to the installer's `InvokeAIAppConfig`,
`ModelRecordServiceBase` and `EventServiceBase` objects.
@ -861,6 +726,120 @@ the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.
# The remainder of this documentation is provisional, pending implementation of the Download and Load services
## Let's get loaded, the lowdown on ModelLoadService
The `ModelLoadService` is responsible for loading a named model into
memory so that it can be used for inference. Despite the fact that it
does a lot under the covers, it is very straightforward to use.
An application-wide model loader is created at API initialization time
and stored in
`ApiDependencies.invoker.services.model_loader`. However, you can
create alternative instances if you wish.
### Creating a ModelLoadService object
The class is defined in
`invokeai.app.services.model_loader_service`. It is initialized with
an InvokeAIAppConfig object, from which it gets configuration
information such as the user's desired GPU and precision, and with a
previously-created `ModelRecordServiceBase` object, from which it
loads the requested model's configuration information.
Here is a typical initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_loader_service import ModelLoadService
config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)
loader = ModelLoadService(config, store)
```
Note that we are relying on the contents of the application
configuration to choose the implementation of
`ModelRecordServiceBase`.
### get_model(key, [submodel_type], [context]) -> ModelInfo:
*** TO DO: change to get_model(key, context=None, **kwargs)
The `get_model()` method, like its similarly-named cousin in
`ModelRecordService`, receives the unique key that identifies the
model. It loads the model into memory, gets the model ready for use,
and returns a `ModelInfo` object.
The optional second argument, `subtype` is a `SubModelType` string
enum, such as "vae". It is mandatory when used with a main model, and
is used to select which part of the main model to load.
The optional third argument, `context` can be provided by
an invocation to trigger model load event reporting. See below for
details.
The returned `ModelInfo` object shares some fields in common with
`ModelConfigBase`, but is otherwise a completely different beast:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `key` | str | The model key derived from the ModelRecordService database |
| `name` | str | Name of this model |
| `base_model` | BaseModelType | Base model for this model |
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
| `location` | Path or str | Location of the model on the filesystem |
| `precision` | torch.dtype | The torch.precision to use for inference |
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
The types for `ModelInfo` and `SubModelType` can be imported from
`invokeai.app.services.model_loader_service`.
To use the model, you use the `ModelInfo` as a context manager using
the following pattern:
```
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
image = vae.decode(latents)[0]
```
The `vae` model will stay locked in the GPU during the period of time
it is in the context manager's scope.
`get_model()` may raise any of the following exceptions:
- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `InvalidModelException` -- the model is guilty of a variety of sins
** TO DO: ** Resolve discrepancy between ModelInfo.location and
ModelConfig.path.
### Emitting model loading events
When the `context` argument is passed to `get_model()`, it will
retrieve the invocation event bus from the passed `InvocationContext`
object to emit events on the invocation bus. The two events are
"model_load_started" and "model_load_completed". Both carry the
following payload:
```
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_key=model_key,
submodel=submodel,
hash=model_info.hash,
location=str(model_info.location),
precision=str(model_info.precision),
)
```
***
## Get on line: The Download Queue
@ -900,6 +879,7 @@ following fields:
| `job_started` | float | | Timestamp for when the job started running |
| `job_ended` | float | | Timestamp for when the job completed or errored out |
| `job_sequence` | int | | A counter that is incremented each time a model is dequeued |
| `preserve_partial_downloads`| bool | False | Resume partial downloads when relaunched. |
| `error` | Exception | | A copy of the Exception that caused an error during download |
When you create a job, you can assign it a `priority`. If multiple
@ -1204,362 +1184,3 @@ other resources that it might have been using.
This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.
***
## This Meta be Good: Model Metadata Storage
The modules found under `invokeai.backend.model_manager.metadata`
provide a straightforward API for fetching model metadatda from online
repositories. Currently two repositories are supported: HuggingFace
and Civitai. However, the modules are easily extended for additional
repos, provided that they have defined APIs for metadata access.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage:
```
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
CivitaiMetadata
ModelMetadataStore,
)
# to access the initialized sql database
from invokeai.app.api.dependencies import ApiDependencies
civitai = CivitaiMetadataFetch()
# fetch the metadata
model_metadata = civitai.from_url("https://civitai.com/models/215796")
# get some common metadata fields
author = model_metadata.author
tags = model_metadata.tags
# get some Civitai-specific fields
assert isinstance(model_metadata, CivitaiMetadata)
trained_words = model_metadata.trained_words
base_model = model_metadata.base_model_trained_on
thumbnail = model_metadata.thumbnail_url
# cache the metadata to the database using the key corresponding to
# an existing model config record in the `model_config` table
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
# now we can search the database by tag, author or model name
# matches will contain a list of model keys that match the search
matches = sql_cache.search_by_tag({"tool", "turbo"})
```
### Structure of the Metadata objects
There is a short class hierarchy of Metadata objects, all of which
descend from the Pydantic `BaseModel`.
#### `ModelMetadataBase`
This is the common base class for metadata:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `name` | str | Repository's name for the model |
| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
to be part of the business logic.
Descendents of the base add additional fields.
#### `HuggingFaceMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
| `id` | str | HuggingFace repo_id |
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
| `id` | int | Civitai model id |
| `version_name` | str | Name of this version of the model (distinct from model name) |
| `version_id` | int | Civitai model version id (distinct from model id) |
| `created` | datetime | Date this version of the model was created |
| `updated` | datetime | Date this version of the model was last updated |
| `published` | datetime | Date this version of the model was published to Civitai |
| `description` | str | Model description. Quite verbose and contains HTML tags |
| `version_description` | str | Model version description, usually describes changes to the model |
| `nsfw` | bool | Whether the model tends to generate NSFW content |
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
| `trained_words`| Set[str] | Trigger words for this model, if any |
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
| `base_model_trained_on` | str | Name of the model that this version was trained on |
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
Note that `weight_min` and `weight_max` are not currently populated
and take the default values of (-1.0, +2.0). The issue is that these
values aren't part of the structured data but appear in the text
description. Some regular expression or LLM coding may be able to
extract these values.
Also be aware that `base_model_trained_on` is free text and doesn't
correspond to our `ModelType` enum.
`CivitaiMetadata` also defines some convenience properties relating to
licensing restrictions: `credit_required`, `allow_commercial_use`,
`allow_derivatives` and `allow_different_license`.
#### `AnyModelRepoMetadata`
This is a discriminated Union of `CivitaiMetadata` and
`HuggingFaceMetadata`.
### Fetching Metadata from Online Repos
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
retrieve metadata from their corresponding repositories and return
`AnyModelRepoMetadata` objects. Their base class
`ModelMetadataFetchBase` is an abstract class that defines two
methods: `from_url()` and `from_id()`. The former accepts the type of
model URLs that the user will try to cut and paste into the model
import form. The latter accepts a string ID in the format recognized
by the repository of choice. Both methods return an
`AnyModelRepoMetadata`.
The base class also has a class method `from_json()` which will take
the JSON representation of a `ModelMetadata` object, validate it, and
return the corresponding `AnyModelRepoMetadata` object.
When initializing one of the metadata fetching classes, you may
provide a `requests.Session` argument. This allows you to customize
the low-level HTTP fetch requests and is used, for instance, in the
testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
`HuggingFaceMetadata` object directly.
#### CivitaiMetadataFetch
This adds the following methods:
`from_civitai_modelid()` This takes the ID of a model, finds the
default version of the model, and then retrieves the metadata for
that version, returning a `CivitaiMetadata` object directly.
`from_civitai_versionid()` This takes the ID of a model version and
retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returna a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
model key in the model_config table_. There is a foreign key integrity
constraint between the `model_config.id` field and the
`model_metadata.id` field such that if you attempt to save metadata
under an unknown key, the attempt will result in an
`UnknownModelException`. Likewise, when a model is deleted from
`model_config`, the deletion of the corresponding metadata record will
be triggered.
Tags are stored in a normalized fashion in the tables `model_tags` and
`tags`. Triggers keep the tag table in sync with the `model_metadata`
table.
To create the storage object, initialize it with the InvokeAI
`SqliteDatabase` object. This is often done this way:
```
from invokeai.app.api.dependencies import ApiDependencies
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
```
You can then access the storage with the following methods:
#### `add_metadata(key, metadata)`
Add the metadata using a previously-defined model key.
There is currently no `delete_metadata()` method. The metadata will
persist until the matching config is deleted from the `model_config`
table.
#### `get_metadata(key) -> AnyModelRepoMetadata`
Retrieve the metadata corresponding to the model key.
#### `update_metadata(key, new_metadata)`
Update an existing metadata record with new metadata.
#### `search_by_tag(tags: Set[str]) -> Set[str]`
Given a set of tags, find models that are tagged with them. If
multiple tags are provided then a matching model must be tagged with
*all* the tags in the set. This method returns a set of model keys and
is intended to be used in conjunction with the `ModelRecordService`:
```
model_config_store = ApiDependencies.invoker.services.model_records
matches = metadata_store.search_by_tag({'license:other'})
models = [model_config_store.get(x) for x in matches]
```
#### `search_by_name(name: str) -> Set[str]
Find all model metadata records that have the given name and return a
set of keys to the corresponding model config objects.
#### `search_by_author(author: str) -> Set[str]
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.
# The remainder of this documentation is provisional, pending implementation of the Load service
## Let's get loaded, the lowdown on ModelLoadService
The `ModelLoadService` is responsible for loading a named model into
memory so that it can be used for inference. Despite the fact that it
does a lot under the covers, it is very straightforward to use.
An application-wide model loader is created at API initialization time
and stored in
`ApiDependencies.invoker.services.model_loader`. However, you can
create alternative instances if you wish.
### Creating a ModelLoadService object
The class is defined in
`invokeai.app.services.model_loader_service`. It is initialized with
an InvokeAIAppConfig object, from which it gets configuration
information such as the user's desired GPU and precision, and with a
previously-created `ModelRecordServiceBase` object, from which it
loads the requested model's configuration information.
Here is a typical initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_loader_service import ModelLoadService
config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)
loader = ModelLoadService(config, store)
```
Note that we are relying on the contents of the application
configuration to choose the implementation of
`ModelRecordServiceBase`.
### get_model(key, [submodel_type], [context]) -> ModelInfo:
*** TO DO: change to get_model(key, context=None, **kwargs)
The `get_model()` method, like its similarly-named cousin in
`ModelRecordService`, receives the unique key that identifies the
model. It loads the model into memory, gets the model ready for use,
and returns a `ModelInfo` object.
The optional second argument, `subtype` is a `SubModelType` string
enum, such as "vae". It is mandatory when used with a main model, and
is used to select which part of the main model to load.
The optional third argument, `context` can be provided by
an invocation to trigger model load event reporting. See below for
details.
The returned `ModelInfo` object shares some fields in common with
`ModelConfigBase`, but is otherwise a completely different beast:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `key` | str | The model key derived from the ModelRecordService database |
| `name` | str | Name of this model |
| `base_model` | BaseModelType | Base model for this model |
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
| `location` | Path or str | Location of the model on the filesystem |
| `precision` | torch.dtype | The torch.precision to use for inference |
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
The types for `ModelInfo` and `SubModelType` can be imported from
`invokeai.app.services.model_loader_service`.
To use the model, you use the `ModelInfo` as a context manager using
the following pattern:
```
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
image = vae.decode(latents)[0]
```
The `vae` model will stay locked in the GPU during the period of time
it is in the context manager's scope.
`get_model()` may raise any of the following exceptions:
- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `InvalidModelException` -- the model is guilty of a variety of sins
** TO DO: ** Resolve discrepancy between ModelInfo.location and
ModelConfig.path.
### Emitting model loading events
When the `context` argument is passed to `get_model()`, it will
retrieve the invocation event bus from the passed `InvocationContext`
object to emit events on the invocation bus. The two events are
"model_load_started" and "model_load_completed". Both carry the
following payload:
```
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_key=model_key,
submodel=submodel,
hash=model_info.hash,
location=str(model_info.location),
precision=str(model_info.precision),
)
```

View File

@ -0,0 +1,76 @@
# Contributing to the Frontend
# InvokeAI Web UI
- [InvokeAI Web UI](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web/docs#invokeai-web-ui)
- [Stack](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web/docs#stack)
- [Contributing](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web/docs#contributing)
- [Dev Environment](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web/docs#dev-environment)
- [Production builds](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web/docs#production-builds)
The UI is a fairly straightforward Typescript React app, with the Unified Canvas being more complex.
Code is located in `invokeai/frontend/web/` for review.
## Stack
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). We lean heavily on RTK:
- `createAsyncThunk` for HTTP requests
- `createEntityAdapter` for fetching images and models
- `createListenerMiddleware` for workflows
The API client and associated types are generated from the OpenAPI schema. See API_CLIENT.md.
Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a simple socket.io redux middleware to help).
[Chakra-UI](https://github.com/chakra-ui/chakra-ui) & [Mantine](https://github.com/mantinedev/mantine) for components and styling.
[Konva](https://github.com/konvajs/react-konva) for the canvas, but we are pushing the limits of what is feasible with it (and HTML canvas in general). We plan to rebuild it with [PixiJS](https://github.com/pixijs/pixijs) to take advantage of WebGL's improved raster handling.
[Vite](https://vitejs.dev/) for bundling.
Localisation is via [i18next](https://github.com/i18next/react-i18next), but translation happens on our [Weblate](https://hosted.weblate.org/engage/invokeai/) project. Only the English source strings should be changed on this repo.
## Contributing
Thanks for your interest in contributing to the InvokeAI Web UI!
We encourage you to ping @psychedelicious and @blessedcoolant on [Discord](https://discord.gg/ZmtBAhwWhy) if you want to contribute, just to touch base and ensure your work doesn't conflict with anything else going on. The project is very active.
### Dev Environment
**Setup**
1. Install [node](https://nodejs.org/en/download/). You can confirm node is installed with:
```bash
node --version
```
2. Install [pnpm](https://pnpm.io/) and confirm it is installed by running this:
```bash
npm install --global pnpm
pnpm --version
```
From `invokeai/frontend/web/` run `pnpm install` to get everything set up.
Start everything in dev mode:
1. Ensure your virtual environment is running
2. Start the dev server: `pnpm dev`
3. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
4. Point your browser to the dev server address e.g. [http://localhost:5173/](http://localhost:5173/)
### VSCode Remote Dev
We've noticed an intermittent issue with the VSCode Remote Dev port forwarding. If you use this feature of VSCode, you may intermittently click the Invoke button and then get nothing until the request times out. Suggest disabling the IDE's port forwarding feature and doing it manually via SSH:
`ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host`
### Production builds
For a number of technical and logistical reasons, we need to commit UI build artefacts to the repo.
If you submit a PR, there is a good chance we will ask you to include a separate commit with a build of the app.
To build for production, run `pnpm build`.

View File

@ -12,7 +12,7 @@ To get started, take a look at our [new contributors checklist](newContributorCh
Once you're setup, for more information, you can review the documentation specific to your area of interest:
* #### [InvokeAI Architecure](../ARCHITECTURE.md)
* #### [Frontend Documentation](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web)
* #### [Frontend Documentation](./contributingToFrontend.md)
* #### [Node Documentation](../INVOCATIONS.md)
* #### [Local Development](../LOCAL_DEVELOPMENT.md)

View File

@ -25,6 +25,7 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [GPT2RandomPromptMaker](#gpt2randompromptmaker)
+ [Grid to Gif](#grid-to-gif)
+ [Halftone](#halftone)
+ [Ideal Size](#ideal-size)
+ [Image and Mask Composition Pack](#image-and-mask-composition-pack)
+ [Image Dominant Color](#image-dominant-color)
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
@ -195,6 +196,13 @@ CMYK Halftone Output:
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/c59c578f-db8e-4d66-8c66-2851752d75ea" width="300" />
--------------------------------
### Ideal Size
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
--------------------------------
### Image and Mask Composition Pack

View File

@ -36,7 +36,6 @@ their descriptions.
| Integer Math | Perform basic math operations on two integers |
| Convert Image Mode | Converts an image to a different mode. |
| Crop Image | Crops an image to a specified box. The box can be outside of the image. |
| Ideal Size | Calculates an ideal image size for latents for a first pass of a multi-pass upscaling to avoid duplication and other artifacts |
| Image Hue Adjustment | Adjusts the Hue of an image. |
| Inverse Lerp Image | Inverse linear interpolation of all pixels of an image |
| Image Primitive | An image primitive value |

View File

@ -3,7 +3,6 @@
from logging import Logger
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -62,7 +61,7 @@ class ApiDependencies:
invoker: Invoker
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
logger.debug(f"Internet connectivity is {config.internet_available}")
@ -88,13 +87,8 @@ class ApiDependencies:
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
download_queue_service = DownloadQueueService(event_bus=events)
metadata_store = ModelMetadataStore(db=db)
model_install_service = ModelInstallService(
app_config=config,
record_store=model_record_service,
download_queue=download_queue_service,
metadata_store=metadata_store,
event_bus=events,
app_config=config, record_store=model_record_service, event_bus=events
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
@ -137,6 +131,6 @@ class ApiDependencies:
db.clean()
@staticmethod
def shutdown() -> None:
def shutdown():
if ApiDependencies.invoker:
ApiDependencies.invoker.stop()

View File

@ -1,28 +0,0 @@
from typing import Any
from starlette.responses import Response
from starlette.staticfiles import StaticFiles
class NoCacheStaticFiles(StaticFiles):
"""
This class is used to override the default caching behavior of starlette for static files,
ensuring we *never* cache static files. It modifies the file response headers to strictly
never cache the files.
Static files include the javascript bundles, fonts, locales, and some images. Generated
images are not included, as they are served by a router.
"""
def __init__(self, *args: Any, **kwargs: Any):
self.cachecontrol = "max-age=0, no-cache, no-store, , must-revalidate"
self.pragma = "no-cache"
self.expires = "0"
super().__init__(*args, **kwargs)
def file_response(self, *args: Any, **kwargs: Any) -> Response:
resp = super().file_response(*args, **kwargs)
resp.headers.setdefault("Cache-Control", self.cachecontrol)
resp.headers.setdefault("Pragma", self.pragma)
resp.headers.setdefault("Expires", self.expires)
return resp

View File

@ -4,7 +4,7 @@
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -16,18 +16,13 @@ from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies
@ -37,20 +32,11 @@ model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager
class ModelsList(BaseModel):
"""Return list of configs."""
models: List[AnyModelConfig]
models: list[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
@model_records_router.get(
"/",
operation_id="list_model_records",
@ -59,7 +45,7 @@ async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[ModelFormat] = Query(
model_format: Optional[str] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
) -> ModelsList:
@ -100,59 +86,6 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.get("/meta", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
@model_records_router.get(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "No metadata available"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_records
result = record_store.get_metadata(key)
if not result:
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
return result
@model_records_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_records
return record_store.list_tags()
@model_records_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
@ -226,7 +159,9 @@ async def del_model_record(
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
"""
Add a model using the configuration information appropriate for its type.
"""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
@ -308,7 +243,7 @@ async def import_model(
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
"model_install_running"
"model_install_started"
"model_install_completed"
"model_install_error"
@ -344,46 +279,16 @@ async def import_model(
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs."""
"""
Return list of model install jobs.
If the optional 'source' argument is provided, then the list will be filtered
for partial string matches against the install source.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
return jobs
@model_records_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
404: {"description": "No such job"},
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source."""
try:
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.cancel_job(job)
@model_records_router.patch(
"/import",
operation_id="prune_model_install_jobs",
@ -393,7 +298,9 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
},
)
async def prune_model_install_jobs() -> Response:
"""Prune all completed and errored jobs from the install job list."""
"""
Prune all completed and errored jobs from the install job list.
"""
ApiDependencies.invoker.services.model_install.prune_jobs()
return Response(status_code=204)
@ -408,9 +315,7 @@ async def prune_model_install_jobs() -> Response:
)
async def sync_models_to_config() -> Response:
"""
Traverse the models and autoimport directories.
Model files without a corresponding
Traverse the models and autoimport directories. Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()

View File

@ -3,7 +3,6 @@
# values from the command line or config file.
import sys
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
from .services.config import InvokeAIAppConfig
@ -28,7 +27,8 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
@ -221,13 +221,19 @@ def overridden_redoc() -> HTMLResponse:
web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here
# Only serve the UI if we it has a build
if (web_root_path / "dist").exists():
# Cannot add headers to StaticFiles, so we must serve index.html with a custom route
# Add cache-control: no-store header to prevent caching of index.html, which leads to broken UIs at release
@app.get("/", include_in_schema=False, name="ui_root")
def get_index() -> FileResponse:
return FileResponse(Path(web_root_path, "dist/index.html"), headers={"Cache-Control": "no-store"})
# Must mount *after* the other routes else it borks em
app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), name="assets")
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
app.mount("/static", StaticFiles(directory=Path(web_root_path, "static/")), name="static") # docs favicon is in here
def invoke_api() -> None:

View File

@ -30,7 +30,6 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from ...backend.model_management import BaseModelType
from .baseinvocation import (
@ -603,33 +602,3 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
color_map = Image.fromarray(color_map)
return color_map
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
@invocation(
"depth_anything_image_processor",
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.0.0",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
offload: bool = InputField(default=False)
def run_processor(self, image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
if image.mode == "RGBA":
image = image.convert("RGB")
processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
return processed_image

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
@ -1229,57 +1228,3 @@ class CropLatentsCoreInvocation(BaseInvocation):
context.services.latents.save(name, cropped_latents)
return build_latents_output(latents_name=name, latents=cropped_latents)
@invocation_output("ideal_size_output")
class IdealSizeOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
width: int = OutputField(description="The ideal width of the image (in pixels)")
height: int = OutputField(description="The ideal height of the image (in pixels)")
@invocation(
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.2",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
aspect = self.width / self.height
dimension = 512
if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
dimension = 768
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = self.trim_to_multiple_of(
math.floor(init_width),
math.floor(init_height),
)
return IdealSizeOutput(width=scaled_width, height=scaled_height)

View File

@ -209,7 +209,7 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Configuration object for InvokeAI App."""
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
singleton_init: ClassVar[Optional[Dict]] = None
# fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
@ -263,7 +263,7 @@ class InvokeAIAppConfig(InvokeAISettings):
# DEVICE
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
# GENERATION
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
@ -301,8 +301,8 @@ class InvokeAIAppConfig(InvokeAISettings):
self,
argv: Optional[list[str]] = None,
conf: Optional[DictConfig] = None,
clobber: Optional[bool] = False,
) -> None:
clobber=False,
):
"""
Update settings with contents of init file, environment, and command-line settings.
@ -337,7 +337,7 @@ class InvokeAIAppConfig(InvokeAISettings):
)
@classmethod
def get_config(cls, **kwargs: Any) -> InvokeAIAppConfig:
def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
"""Return a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
@ -455,7 +455,7 @@ class InvokeAIAppConfig(InvokeAISettings):
return _find_root()
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -34,7 +34,6 @@ class ServiceInactiveException(Exception):
DownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
@total_ordering
@ -56,7 +55,6 @@ class DownloadJob(BaseModel):
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
@ -72,11 +70,7 @@ class DownloadJob(BaseModel):
_on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
_on_error: Optional[DownloadEventHandler] = PrivateAttr(default=None)
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
@ -93,26 +87,6 @@ class DownloadJob(BaseModel):
"""Call to cancel the job."""
return self._cancelled
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED
@property
def running(self) -> bool:
"""Return true if the job is running."""
return self.status == DownloadJobStatus.RUNNING
@property
def errored(self) -> bool:
"""Return true if the job is errored."""
return self.status == DownloadJobStatus.ERROR
@property
def in_terminal_state(self) -> bool:
"""Return true if job has finished, one way or another."""
return self.status not in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]
@property
def on_start(self) -> Optional[DownloadEventHandler]:
"""Return the on_start event handler."""
@ -129,7 +103,7 @@ class DownloadJob(BaseModel):
return self._on_complete
@property
def on_error(self) -> Optional[DownloadExceptionHandler]:
def on_error(self) -> Optional[DownloadEventHandler]:
"""Return the on_error event handler."""
return self._on_error
@ -144,7 +118,7 @@ class DownloadJob(BaseModel):
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> None:
"""Set the callbacks for download events."""
self._on_start = on_start
@ -176,10 +150,10 @@ class DownloadQueueServiceBase(ABC):
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> DownloadJob:
"""
Create and enqueue download job.
Create a download job.
:param source: Source of the download as a URL.
:param dest: Path to download to. See below.
@ -201,25 +175,6 @@ class DownloadQueueServiceBase(ABC):
"""
pass
@abstractmethod
def submit_download_job(
self,
job: DownloadJob,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> None:
"""
Enqueue a download job.
:param job: The DownloadJob
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJob]:
"""
@ -242,21 +197,21 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def cancel_all_jobs(self) -> None:
def cancel_all_jobs(self):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def prune_jobs(self) -> None:
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod
def cancel_job(self, job: DownloadJob) -> None:
def cancel_job(self, job: DownloadJob):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def join(self) -> None:
def join(self):
"""Wait until all jobs are off the queue."""
pass

View File

@ -5,9 +5,10 @@ import os
import re
import threading
import traceback
from logging import Logger
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
@ -20,7 +21,6 @@ from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadExceptionHandler,
DownloadJob,
DownloadJobCancelledException,
DownloadJobStatus,
@ -36,6 +36,18 @@ DOWNLOAD_CHUNK_SIZE = 100000
class DownloadQueueService(DownloadQueueServiceBase):
"""Class for queued download of models."""
_jobs: Dict[int, DownloadJob]
_max_parallel_dl: int = 5
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue[DownloadJob]
_stop_event: threading.Event
_lock: threading.Lock
_logger: Logger
_events: Optional[EventServiceBase] = None
_next_job_id: int = 0
_accept_download_requests: bool = False
_requests: requests.sessions.Session
def __init__(
self,
max_parallel_dl: int = 5,
@ -87,33 +99,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._stop_event.set()
self._worker_pool.clear()
def submit_download_job(
self,
job: DownloadJob,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> None:
"""Enqueue a download job."""
if not self._accept_download_requests:
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
with self._lock:
job.id = self._next_job_id
self._next_job_id += 1
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
def download(
self,
source: AnyHttpUrl,
@ -124,27 +109,32 @@ class DownloadQueueService(DownloadQueueServiceBase):
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> DownloadJob:
"""Create and enqueue a download job and return it."""
"""Create a download job and return its ID."""
if not self._accept_download_requests:
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
job = DownloadJob(
source=source,
dest=dest,
priority=priority,
access_token=access_token,
)
self.submit_download_job(
job,
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
with self._lock:
id = self._next_job_id
self._next_job_id += 1
job = DownloadJob(
id=id,
source=source,
dest=dest,
priority=priority,
access_token=access_token,
)
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[id] = job
self._queue.put(job)
return job
def join(self) -> None:
@ -160,7 +150,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
with self._lock:
to_delete = set()
for job_id, job in self._jobs.items():
if job.in_terminal_state:
if self._in_terminal_state(job):
to_delete.add(job_id)
for job_id in to_delete:
del self._jobs[job_id]
@ -182,12 +172,19 @@ class DownloadQueueService(DownloadQueueServiceBase):
with self._lock:
job.cancel()
def cancel_all_jobs(self) -> None:
def cancel_all_jobs(self, preserve_partial: bool = False) -> None:
"""Cancel all jobs (those not in enqueued, running or paused state)."""
for job in self._jobs.values():
if not job.in_terminal_state:
if not self._in_terminal_state(job):
self.cancel_job(job)
def _in_terminal_state(self, job: DownloadJob) -> bool:
return job.status in [
DownloadJobStatus.COMPLETED,
DownloadJobStatus.CANCELLED,
DownloadJobStatus.ERROR,
]
def _start_workers(self, max_workers: int) -> None:
"""Start the requested number of worker threads."""
self._stop_event.clear()
@ -217,7 +214,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
self._signal_job_error(job)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
@ -238,8 +235,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
resp = self._requests.get(str(url), headers=header, stream=True)
if not resp.ok:
raise HTTPError(resp.reason)
job.content_type = resp.headers.get("Content-Type")
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
@ -301,7 +296,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._signal_job_progress(job)
# if we get here we are done and can rename the file to the original dest
self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
in_progress_path.rename(job.download_path)
def _validate_filename(self, directory: str, filename: str) -> bool:
@ -328,9 +322,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
try:
job.on_start(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
@ -340,9 +332,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
try:
job.on_progress(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
@ -358,9 +348,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
try:
job.on_complete(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
@ -368,36 +356,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return
job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
self._logger.error(e)
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
def _signal_job_error(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
if job.on_error:
try:
job.on_error(job, excp)
job.on_error(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
self._logger.error(e)
if self._event_bus:
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.download_path}")
try:
if job.download_path:
partial_file = self._in_progress_path(job.download_path)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
@ -404,72 +404,53 @@ class EventServiceBase:
},
)
def emit_model_install_downloading(
self,
source: str,
local_path: str,
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
) -> None:
def emit_model_install_started(self, source: str) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
:param source: Source of the model
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
},
)
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.
Emitted when an install job is started.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_running",
event_name="model_install_started",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
def emit_model_install_completed(self, source: str, key: str) -> None:
"""
Emit when an install job is completed successfully.
Emitted when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={
"source": source,
"total_bytes": total_bytes,
"key": key,
},
)
def emit_model_install_cancelled(self, source: str) -> None:
def emit_model_install_progress(
self,
source: str,
current_bytes: int,
total_bytes: int,
) -> None:
"""
Emit when an install job is cancelled.
Emitted while the install job is in progress.
(Downloaded models only)
:param source: Source of the model; local path, repo_id or url
:param source: Source of the model
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: Total bytes to download
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source},
event_name="model_install_progress",
payload={
"source": source,
"current_bytes": int,
"total_bytes": int,
},
)
def emit_model_install_error(
@ -479,11 +460,10 @@ class EventServiceBase:
error: str,
) -> None:
"""
Emit when an install job encounters an exception.
Emitted when an install job encounters an exception.
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
:param exception: The exception that raised the error
"""
self.__emit_model_event(
event_name="model_install_error",

View File

@ -1,3 +1,4 @@
import cProfile
import time
import traceback
from threading import BoundedSemaphore, Event, Thread
@ -39,6 +40,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__threadLimit.acquire()
queue_item: Optional[InvocationQueueItem] = None
profiler = None
last_gesid = None
while not stop_event.is_set():
try:
queue_item = self.__invoker.services.queue.get()
@ -49,6 +53,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# do not hammer the queue
time.sleep(0.5)
continue
if last_gesid != queue_item.graph_execution_state_id:
if profiler is not None:
# I'm not sure what would cause us to get here, but if we do, we should restart the profiler for
# the new graph_execution_state_id.
profiler.disable()
logger.info(f"Stopped profiler for {last_gesid}.")
profiler = None
last_gesid = None
profiler = cProfile.Profile()
profiler.enable()
last_gesid = queue_item.graph_execution_state_id
logger.info(f"Started profiling {last_gesid}.")
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
@ -132,6 +151,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
source_node_id=source_node_id,
result=outputs.model_dump(),
)
self.__invoker.services.performance_statistics.log_stats()
except KeyboardInterrupt:
pass
@ -194,13 +214,19 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
)
if profiler is not None:
profiler.disable()
dump_path = f"{last_gesid}.prof"
profiler.dump_stats(dump_path)
logger.info(f"Saved profile to {dump_path}.")
profiler = None
last_gesid = None
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor

View File

@ -30,13 +30,23 @@ writes to the system log is stored in InvocationServices.performance_statistics.
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from typing import Dict
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.backend.model_management.model_cache import CacheStats
from .invocation_stats_common import NodeLog
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
ram_used: float
ram_changed: float
@abstractmethod
def __init__(self):
"""
@ -67,8 +77,45 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def log_stats(self, graph_execution_state_id: str):
def reset_all_stats(self):
"""Zero all statistics"""
pass
@abstractmethod
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
"""
pass
@abstractmethod
def log_stats(self):
"""
Write out the accumulated statistics to the log or somewhere else.
"""
pass
@abstractmethod
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
pass

View File

@ -1,84 +1,25 @@
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict
# size of GIG in bytes
GIG = 1073741824
@dataclass
class NodeExecutionStats:
"""Class for tracking execution stats of an invocation node."""
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
invocation_type: str
start_time: float # Seconds since the epoch.
end_time: float # Seconds since the epoch.
start_ram_gb: float # GB
end_ram_gb: float # GB
peak_vram_gb: float # GB
def total_time(self) -> float:
return self.end_time - self.start_time
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
class GraphExecutionStats:
"""Class for tracking execution stats of a graph."""
@dataclass
class NodeLog:
"""Class for tracking node usage"""
def __init__(self):
self._node_stats_list: list[NodeExecutionStats] = []
def add_node_execution_stats(self, node_stats: NodeExecutionStats):
self._node_stats_list.append(node_stats)
def get_total_run_time(self) -> float:
"""Get the total time spent executing nodes in the graph."""
total = 0.0
for node_stats in self._node_stats_list:
total += node_stats.total_time()
return total
def get_first_node_stats(self) -> NodeExecutionStats | None:
"""Get the stats of the first node in the graph (by start_time)."""
first_node = None
for node_stats in self._node_stats_list:
if first_node is None or node_stats.start_time < first_node.start_time:
first_node = node_stats
assert first_node is not None
return first_node
def get_last_node_stats(self) -> NodeExecutionStats | None:
"""Get the stats of the last node in the graph (by end_time)."""
last_node = None
for node_stats in self._node_stats_list:
if last_node is None or node_stats.end_time > last_node.end_time:
last_node = node_stats
return last_node
def get_pretty_log(self, graph_execution_state_id: str) -> str:
log = f"Graph stats: {graph_execution_state_id}\n"
log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n"
# Log stats aggregated by node type.
node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
for node_stats in self._node_stats_list:
node_stats_by_type[node_stats.invocation_type].append(node_stats)
for node_type, node_type_stats_list in node_stats_by_type.items():
num_calls = len(node_type_stats_list)
time_used = sum([n.total_time() for n in node_type_stats_list])
peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
# Log stats for the entire graph.
log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
first_node = self.get_first_node_stats()
last_node = self.get_last_node_stats()
if first_node is not None and last_node is not None:
total_wall_time = last_node.end_time - first_node.start_time
ram_change = last_node.end_ram_gb - first_node.start_ram_gb
log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
return log
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)

View File

@ -1,5 +1,5 @@
import time
from contextlib import contextmanager
from typing import Dict
import psutil
import torch
@ -7,119 +7,161 @@ import torch
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.backend.model_management.model_cache import CacheStats
from .invocation_stats_base import InvocationStatsServiceBase
from .invocation_stats_common import GraphExecutionStats, NodeExecutionStats
# Size of 1GB in bytes.
GB = 2**30
from .invocation_stats_common import GIG, NodeLog, NodeStats
class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems"""
_invoker: Invoker
def __init__(self):
# Maps graph_execution_state_id to GraphExecutionStats.
self._stats: dict[str, GraphExecutionStats] = {}
# Maps graph_execution_state_id to model manager CacheStats.
self._cache_stats: dict[str, CacheStats] = {}
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0
self.ram_changed: float = 0.0
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
self._stats[graph_execution_state_id] = GraphExecutionStats()
self._cache_stats[graph_execution_state_id] = CacheStats()
class StatsContext:
"""Context manager for collecting statistics."""
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
self._prune_stale_stats()
invocation: BaseInvocation
collector: "InvocationStatsServiceBase"
graph_id: str
start_time: float
ram_used: int
model_manager: ModelManagerServiceBase
# Record state before the invocation.
start_time = time.time()
start_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
if self._invoker.services.model_manager:
self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerServiceBase,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0.0
self.ram_used = 0
self.model_manager = model_manager
try:
# Let the invocation run.
yield None
finally:
# Record state after the invocation.
node_stats = NodeExecutionStats(
invocation_type=invocation.type,
start_time=start_time,
end_time=time.time(),
start_ram_gb=start_ram / GB,
end_ram_gb=psutil.Process().memory_info().rss / GB,
peak_vram_gb=torch.cuda.max_memory_allocated() / GB if torch.cuda.is_available() else 0.0,
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_used = psutil.Process().memory_info().rss
if self.model_manager:
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
def __exit__(self, *args):
"""Called on exit from the context."""
ram_used = psutil.Process().memory_info().rss
self.collector.update_mem_stats(
ram_used=ram_used / GIG,
ram_changed=(ram_used - self.ram_used) / GIG,
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type, # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def _prune_stale_stats(self):
"""Check all graphs being tracked and prune any that have completed/errored.
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> StatsContext:
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
for now we call this function periodically to prevent them from accumulating.
"""
to_prune = []
for graph_execution_state_id in self._stats:
def reset_all_stats(self):
"""Zero all statistics"""
self._stats = {}
def reset_stats(self, graph_execution_id: str):
try:
self._stats.pop(graph_execution_id)
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
self.ram_used = ram_used
self.ram_changed = ram_changed
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
stats.calls += 1
stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
completed = set()
errored = set()
for graph_id, _node_log in self._stats.items():
try:
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
except Exception:
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
errored.add(graph_id)
continue
if not graph_execution_state.is_complete():
# The graph is still running, don't prune it.
if not current_graph_state.is_complete():
continue
to_prune.append(graph_execution_state_id)
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
total_time += stats.time_used
for graph_execution_state_id in to_prune:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
if len(to_prune) > 0:
logger.info(f"Pruned stale graph stats for {to_prune}.")
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
logger.info(f"RAM used to load models: {loaded:4.2f}G")
if torch.cuda.is_available():
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
logger.info("RAM cache statistics:")
logger.info(f" Model cache hits: {cache_stats.hits}")
logger.info(f" Model cache misses: {cache_stats.misses}")
logger.info(f" Models cached: {cache_stats.in_cache}")
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
def reset_stats(self, graph_execution_state_id: str):
try:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
except KeyError as e:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
completed.add(graph_id)
def log_stats(self, graph_execution_state_id: str):
try:
graph_stats = self._stats[graph_execution_state_id]
cache_stats = self._cache_stats[graph_execution_state_id]
except KeyError as e:
logger.warning(f"Attempted to log statistics for unknown graph {graph_execution_state_id}: {e}.")
return
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]
log = graph_stats.get_pretty_log(graph_execution_state_id)
hwm = cache_stats.high_watermark / GB
tot = cache_stats.cache_size / GB
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
log += f"RAM used to load models: {loaded:4.2f}G\n"
if torch.cuda.is_available():
log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
log += "RAM cache statistics:\n"
log += f" Model cache hits: {cache_stats.hits}\n"
log += f" Model cache misses: {cache_stats.misses}\n"
log += f" Models cached: {cache_stats.in_cache}\n"
log += f" Models cleared from cache: {cache_stats.cleared}\n"
log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
logger.info(log)
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
for graph_id in errored:
del self._stats[graph_id]
del self._cache_stats[graph_id]

View File

@ -1,7 +1,6 @@
"""Initialization file for model install service package."""
from .model_install_base import (
CivitaiModelSource,
HFModelSource,
InstallStatus,
LocalModelSource,
@ -23,5 +22,4 @@ __all__ = [
"LocalModelSource",
"HFModelSource",
"URLModelSource",
"CivitaiModelSource",
]

View File

@ -1,42 +1,27 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic import BaseModel, Field, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
from invokeai.backend.model_manager import AnyModelConfig
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
@ -89,31 +74,12 @@ class LocalModelSource(StringLikeSource):
return Path(self.path).as_posix()
class CivitaiModelSource(StringLikeSource):
"""A Civitai version id, with optional variant and access token."""
version_id: int
variant: Optional[ModelRepoVariant] = None
access_token: Optional[str] = None
type: Literal["civitai"] = "civitai"
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = str(self.version_id)
base += f" ({self.variant})" if self.variant else ""
return base
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
"""A HuggingFace repo_id, with optional variant and sub-folder."""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
variant: Optional[str] = None
subfolder: Optional[str | Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@ -137,22 +103,19 @@ class URLModelSource(StringLikeSource):
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
type: Literal["generic_url"] = "generic_url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
@ -165,74 +128,15 @@ class ModelInstallJob(BaseModel):
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: Optional[int] = Field(
default=None, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error_type = e.__class__.__name__
self.error = "".join(traceback.format_exception(e))
self.status = InstallStatus.ERROR
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
@property
def error(self) -> Optional[str]:
"""Error traceback."""
return "".join(traceback.format_exception(self._exception)) if self._exception else None
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@ -242,8 +146,6 @@ class ModelInstallServiceBase(ABC):
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStore,
event_bus: Optional["EventServiceBase"] = None,
):
"""
@ -254,14 +156,12 @@ class ModelInstallServiceBase(ABC):
:param event_bus: InvokeAI event bus for reporting events to.
"""
# make the invoker optional here because we don't need it and it
# makes the installer harder to use outside the web app
@abstractmethod
def start(self, invoker: Optional[Invoker] = None) -> None:
def start(self, *args: Any, **kwarg: Any) -> None:
"""Start the installer service."""
@abstractmethod
def stop(self, invoker: Optional[Invoker] = None) -> None:
def stop(self, *args: Any, **kwarg: Any) -> None:
"""Stop the model install service. After this the objection can be safely deleted."""
@property
@ -364,13 +264,9 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""
@abstractmethod
def get_job_by_id(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob corresponding to the provided id. Raises ValueError if no job has that ID."""
@abstractmethod
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
"""
@ -382,19 +278,16 @@ class ModelInstallServiceBase(ABC):
"""Prune all completed and errored jobs."""
@abstractmethod
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
@abstractmethod
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
def wait_for_installs(self) -> List[ModelInstallJob]:
"""
Wait for all pending installs to complete.
This will block until all pending installs have
completed, been cancelled, or errored out.
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
installs do not complete within the indicated time.
It will return the current list of jobs.
"""
@abstractmethod

View File

@ -1,72 +1,60 @@
"""Model installation class."""
import os
import re
import threading
import time
from hashlib import sha256
from logging import Logger
from pathlib import Path
from queue import Empty, Queue
from queue import Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelRepoVariant,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
ModelMetadataStore,
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import (
CivitaiModelSource,
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
URLModelSource,
)
TMPDIR_PREFIX = "tmpinstall_"
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(
source=LocalModelSource(path="stop"),
local_path=Path("/dev/null"),
)
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
_app_config: InvokeAIAppConfig
_record_store: ModelRecordServiceBase
_event_bus: Optional[EventServiceBase] = None
_install_queue: Queue[ModelInstallJob]
_install_jobs: List[ModelInstallJob]
_logger: Logger
_cached_model_paths: Set[Path]
_models_installed: Set[str]
def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: Optional[ModelMetadataStore] = None,
event_bus: Optional[EventServiceBase] = None,
session: Optional[Session] = None,
):
"""
Initialize the installer object.
@ -79,26 +67,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._record_store = record_store
self._event_bus = event_bus
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
self._install_jobs: List[ModelInstallJob] = []
self._install_queue: Queue[ModelInstallJob] = Queue()
self._cached_model_paths: Set[Path] = set()
self._models_installed: Set[str] = set()
self._lock = threading.Lock()
self._stop_event = threading.Event()
self._downloads_changed_event = threading.Event()
self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False
self._session = session
self._next_job_id = 0
# There may not necessarily be a metadata store initialized
# so we create one and initialize it with the same sql database
# used by the record store service.
if metadata_store:
self._metadata_store = metadata_store
else:
assert isinstance(record_store, ModelRecordServiceSQL)
self._metadata_store = ModelMetadataStore(record_store.db)
self._install_jobs = []
self._install_queue = Queue()
self._cached_model_paths = set()
self._models_installed = set()
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -112,31 +84,69 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
# make the invoker optional here because we don't need it and it
# makes the installer harder to use outside the web app
def start(self, invoker: Optional[Invoker] = None) -> None:
def start(self, *args: Any, **kwarg: Any) -> None:
"""Start the installer thread."""
with self._lock:
if self._running:
raise Exception("Attempt to start the installer service twice")
self._start_installer_thread()
self._remove_dangling_install_dirs()
self.sync_to_config()
self._start_installer_thread()
self.sync_to_config()
def stop(self, invoker: Optional[Invoker] = None) -> None:
def stop(self, *args: Any, **kwarg: Any) -> None:
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
with self._lock:
if not self._running:
raise Exception("Attempt to stop the install service before it was started")
self._stop_event.set()
with self._install_queue.mutex:
self._install_queue.queue.clear() # get rid of pending jobs
active_jobs = [x for x in self.list_jobs() if x.running]
if active_jobs:
self._logger.warning("Waiting for active install job to complete")
self.wait_for_installs()
self._download_cache.clear()
self._running = False
self._install_queue.put(STOP_JOB)
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
def _install_next_item(self) -> None:
done = False
while not done:
job = self._install_queue.get()
if job == STOP_JOB:
done = True
continue
assert job.local_path is not None
try:
self._signal_job_running(job)
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
self._signal_job_errored(job, excp)
finally:
self._install_queue.task_done()
self._logger.info("Install thread exiting")
def _signal_job_running(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.RUNNING
self._logger.info(f"{job.source}: model installation started")
if self._event_bus:
self._event_bus.emit_model_install_started(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
assert job.config_out
self._logger.info(
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
)
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key)
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
job.set_error(excp)
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
if self._event_bus:
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
def register_path(
self,
@ -162,12 +172,7 @@ class ModelInstallService(ModelInstallServiceBase):
info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.original_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
try:
new_path = self._copy_model(model_path, dest_path)
except FileExistsError as excp:
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
) from excp
new_path = self._copy_model(model_path, dest_path)
new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
@ -177,56 +182,43 @@ class ModelInstallService(ModelInstallServiceBase):
info,
)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
if isinstance(source, LocalModelSource):
install_job = self._import_local_model(source, config)
self._install_queue.put(install_job) # synchronously install
elif isinstance(source, CivitaiModelSource):
install_job = self._import_from_civitai(source, config)
elif isinstance(source, HFModelSource):
install_job = self._import_from_hf(source, config)
elif isinstance(source, URLModelSource):
install_job = self._import_from_url(source, config)
else:
raise ValueError(f"Unsupported model source: '{type(source)}'")
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob: # noqa D102
if not config:
config = {}
self._install_jobs.append(install_job)
return install_job
# Installing a local path
if isinstance(source, LocalModelSource) and Path(source.path).exists(): # a path that is already on disk
job = ModelInstallJob(
source=source,
config_in=config,
local_path=Path(source.path),
)
self._install_jobs.append(job)
self._install_queue.put(job)
return job
else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
raise UnknownModelException("File or directory not found")
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
return [x for x in self._install_jobs if x.source == source]
def get_job_by_id(self, id: int) -> ModelInstallJob: # noqa D102
jobs = [x for x in self._install_jobs if x.id == id]
if not jobs:
raise ValueError(f"No job with id {id} known")
assert len(jobs) == 1
assert isinstance(jobs[0], ModelInstallJob)
return jobs[0]
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done."""
start = time.time()
while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise Exception("Timeout exceeded")
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
self._install_queue.join()
return self._install_jobs
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
job.cancel()
with self._lock:
self._cancel_download_parts(job)
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
unfinished_jobs = [
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
]
self._install_jobs = unfinished_jobs
def sync_to_config(self) -> None:
@ -242,108 +234,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._models_installed.clear()
self._models_installed: Set[str] = set()
search.search(scan_dir)
return list(self._models_installed)
def unregister(self, key: str) -> None: # noqa D102
self.record_store.del_model(key)
def delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
self.unregister(key)
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
path = self.app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
self._running = True
def _install_next_item(self) -> None:
done = False
while not done:
if self._stop_event.is_set():
done = True
continue
try:
job = self._install_queue.get(timeout=1)
except Empty:
continue
assert job.local_path is not None
try:
if job.cancelled:
self._signal_job_cancelled(job)
elif job.errored:
self._signal_job_errored(job)
elif (
job.waiting or job.downloading
): # local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
# enter the metadata, if there is any
if job.source_metadata:
self._metadata_store.add_metadata(key, job.source_metadata)
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
except (OSError, DuplicateModelException) as excp:
job.set_error(excp)
self._signal_job_errored(job)
finally:
# if this is an install of a remote file, then clean up the temporary directory
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._install_queue.task_done()
self._logger.info("Install thread exiting")
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
# --------------------------------------------------------------------------------------------
def _remove_dangling_install_dirs(self) -> None:
"""Remove leftover tmpdirs from aborted installs."""
path = self._app_config.models_path
for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
self._logger.info(f"Removing dangling temporary directory {tmpdir}")
rmtree(tmpdir)
def _scan_models_directory(self) -> None:
"""
Scan the models directory for new and missing models.
@ -426,6 +320,28 @@ class ModelInstallService(ModelInstallServiceBase):
pass
return True
def unregister(self, key: str) -> None: # noqa D102
self.record_store.del_model(key)
def delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
self.unregister(key)
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
path = self.app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
@ -481,279 +397,3 @@ class ModelInstallService(ModelInstallServiceBase):
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(key, info)
return key
def _next_id(self) -> int:
with self._lock:
id = self._next_job_id
self._next_job_id += 1
return id
@staticmethod
def _guess_variant() -> ModelRepoVariant:
"""Guess the best HuggingFace variant type to download."""
precision = choose_precision(choose_torch_device())
return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or {},
local_path=Path(source.path),
inplace=source.inplace,
)
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
if not source.access_token:
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(session=self._session)
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token()
if not source.access_token:
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
session=self._session,
)
return self._import_remote_model(
source=source,
config=config,
remote_files=remote_files,
metadata=metadata,
)
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from Civitai or HuggingFace will be handled specially
url_patterns = {
r"https?://civitai.com/": CivitaiMetadataFetch,
r"https?://huggingface.co/": HuggingFaceMetadataFetch,
}
metadata = None
for pattern, fetcher in url_patterns.items():
if re.match(pattern, str(source.url), re.IGNORECASE):
metadata = fetcher(self._session).from_url(source.url)
break
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
return self._import_remote_model(
source=source,
config=config,
metadata=metadata,
remote_files=remote_files,
)
def _import_remote_model(
self,
source: ModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
tmpdir = Path(
mkdtemp(
dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX,
)
)
install_job = ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or {},
source_metadata=metadata,
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
bytes=0,
total_bytes=0,
)
# we remember the path up to the top of the tmpdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
self._logger.info(f"Queuing {source} for downloading")
for model_file in remote_files:
url = model_file.url
path = model_file.path
self._logger.info(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
dest = tmpdir / path.parent
dest.mkdir(parents=True, exist_ok=True)
download_job = DownloadJob(
source=url,
dest=dest,
access_token=source.access_token,
)
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
return install_job
def _stat_size(self, path: Path) -> int:
size = 0
if path.is_file():
size = path.stat().st_size
elif path.is_dir():
for root, _, files in os.walk(path):
size += sum(self._stat_size(Path(root, x)) for x in files)
return size
# ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------
def _download_started_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download started")
with self._lock:
install_job = self._download_cache[download_job.source]
install_job.status = InstallStatus.DOWNLOADING
assert download_job.download_path
if install_job.local_path == install_job._install_tmpdir:
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
dest_name = partial_path.parts[0]
install_job.local_path = install_job._install_tmpdir / dest_name
# Update the total bytes count for remote sources.
if not install_job.total_bytes:
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
def _download_progress_callback(self, download_job: DownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._cancel_download_parts(install_job)
else:
# update sizes
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
self._download_cache.pop(download_job.source, None)
# are there any more active jobs left in this task?
if all(x.complete for x in install_job.download_parts):
# now enqueue job for actual installation into the models directory
self._install_queue.put(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
assert install_job is not None
assert excp is not None
install_job.set_error(excp)
self._logger.error(
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
)
self._cancel_download_parts(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
if not install_job:
return
self._downloads_changed_event.set()
self._logger.warning(f"Download {download_job.source} cancelled.")
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
self._cancel_download_parts(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
# do not lock here because it gets called within a locked context
for s in install_job.download_parts:
self._download_queue.cancel_job(s)
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._install_queue.put(install_job)
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
# ------------------------------------------------------------------------------------------------
def _signal_job_running(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.RUNNING
self._logger.info(f"{job.source}: model installation started")
if self._event_bus:
self._event_bus.emit_model_install_running(str(job.source))
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
parts: List[Dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading(
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
assert job.config_out
self._logger.info(
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
)
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
if self._event_bus:
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source))

View File

@ -4,8 +4,6 @@ from .model_records_base import ( # noqa F401
InvalidModelException,
ModelRecordServiceBase,
UnknownModelException,
ModelSummary,
ModelRecordOrderBy,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401
@ -15,6 +13,4 @@ __all__ = [
"DuplicateModelException",
"InvalidModelException",
"UnknownModelException",
"ModelSummary",
"ModelRecordOrderBy",
]

View File

@ -4,15 +4,10 @@ Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
class DuplicateModelException(Exception):
@ -31,33 +26,11 @@ class ConfigFileVersionMismatchException(Exception):
"""Raised on an attempt to open a config with an incompatible version."""
class ModelRecordOrderBy(str, Enum):
"""The order in which to return model summaries."""
Default = "default" # order by type, base, format and name
Type = "type"
Base = "base"
Name = "name"
Format = "format"
class ModelSummary(BaseModel):
"""A short summary of models for UI listing purposes."""
key: str = Field(description="model key")
type: ModelType = Field(description="model type")
base: BaseModelType = Field(description="base model")
format: ModelFormat = Field(description="model format")
name: str = Field(description="model name")
description: str = Field(description="short description of model")
tags: Set[str] = Field(description="tags associated with model")
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@abstractmethod
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
@ -81,7 +54,7 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
@ -102,47 +75,6 @@ class ModelRecordServiceBase(ABC):
"""
pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:
"""Return a ModelMetadataStore initialized on the same database."""
pass
@abstractmethod
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
pass
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
pass
@abstractmethod
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
pass
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
pass
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""

View File

@ -42,11 +42,9 @@ Typical usage:
import json
import sqlite3
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Union
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -54,14 +52,11 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
DuplicateModelException,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelSummary,
UnknownModelException,
)
@ -69,6 +64,9 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -80,12 +78,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._cursor = self._db.conn.cursor()
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
@ -300,95 +293,3 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
@property
def metadata_store(self) -> ModelMetadataStore:
"""Return a ModelMetadataStore initialized on the same database."""
return ModelMetadataStore(self._db)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
store = self.metadata_store
try:
metadata = store.get_metadata(key)
return metadata
except UnknownMetadataException:
return None
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStore(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStore(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStore(self._db)
return store.list_all_metadata()
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
ordering = {
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
ModelRecordOrderBy.Type: "a.type",
ModelRecordOrderBy.Base: "a.base",
ModelRecordOrderBy.Name: "a.name",
ModelRecordOrderBy.Format: "a.format",
}
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
"""Fix up results so that there are no null values."""
result: Dict[str, Union[str, int, Set[str]]] = {}
for key, item in summary.items():
result[key] = item or ""
result["tags"] = set(json.loads(summary["tags"] or "[]"))
return result
# Lock so that the database isn't updated while we're doing the two queries.
with self._db.lock:
# query1: get the total number of model configs
self._cursor.execute(
"""--sql
select count(*) from model_config;
""",
(),
)
total = int(self._cursor.fetchone()[0])
# query2: fetch key fields from the join of model_config and model_metadata
self._cursor.execute(
f"""--sql
SELECT a.id as key, a.type, a.base, a.format, a.name,
json_extract(a.config, '$.description') as description,
json_extract(b.metadata, '$.tags') as tags
FROM model_config AS a
LEFT JOIN model_metadata AS b on a.id=b.id
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)

View File

@ -6,7 +6,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -29,8 +28,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_3())
migrator.run_migrations()
return db

View File

@ -11,6 +11,8 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
UnsafeWorkflowWithVersionValidator,
)
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration2Callback:
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
@ -23,6 +25,8 @@ class Migration2Callback:
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_model_config_records(cursor)
self._migrate_embedded_workflows(cursor)
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
@ -96,6 +100,45 @@ class Migration2Callback:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
model_record_migrator = MigrateModelYamlToDb1(cursor)
model_record_migrator.migrate()
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in

View File

@ -1,16 +1,13 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration3Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __init__(self) -> None:
pass
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_model_manager_metadata(cursor)
@ -57,12 +54,11 @@ class Migration3Callback:
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
self._logger.info("Migrating model config records from models.yaml to database")
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
model_record_migrator = MigrateModelYamlToDb1(cursor)
model_record_migrator.migrate()
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
def build_migration_3() -> Migration:
"""
Build the migration from database version 2 to 3.
@ -73,7 +69,7 @@ def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migratio
migration_3 = Migration(
from_version=2,
to_version=3,
callback=Migration3Callback(app_config=app_config, logger=logger),
callback=Migration3Callback(),
)
return migration_3

View File

@ -1,83 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration4Callback:
"""Callback to do step 4 of migration."""
def __call__(self, cursor: sqlite3.Cursor) -> None: # noqa D102
self._create_model_metadata(cursor)
self._create_model_tags(cursor)
self._create_tags(cursor)
self._create_triggers(cursor)
def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store model metadata downloaded from remote sources."""
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_metadata (
id TEXT NOT NULL PRIMARY KEY,
name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL,
author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL,
-- Serialized JSON representation of the whole metadata object,
-- which will contain additional fields from subclasses
metadata TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
FOREIGN KEY(id) REFERENCES model_config(id) ON DELETE CASCADE
);
"""
)
def _create_model_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_tags (
model_id TEXT NOT NULL,
tag_id INTEGER NOT NULL,
FOREIGN KEY(model_id) REFERENCES model_config(id) ON DELETE CASCADE,
FOREIGN KEY(tag_id) REFERENCES tags(tag_id) ON DELETE CASCADE,
UNIQUE(model_id,tag_id)
);
"""
)
def _create_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
tag_id INTEGER NOT NULL PRIMARY KEY,
tag_text TEXT NOT NULL UNIQUE
);
"""
)
def _create_triggers(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at
AFTER UPDATE
ON model_metadata FOR EACH ROW
BEGIN
UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
def build_migration_4() -> Migration:
"""
Build the migration from database version 3 to 4.
Adds the tables needed to store model metadata and tags.
"""
migration_4 = Migration(
from_version=3,
to_version=4,
callback=Migration4Callback(),
)
return migration_4

View File

@ -23,6 +23,7 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
@ -45,9 +46,10 @@ class MigrateModelYamlToDb1:
logger: Logger
cursor: sqlite3.Cursor
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
self.config = config
self.logger = logger
def __init__(self, cursor: sqlite3.Cursor = None) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
self.cursor = cursor
def get_yaml(self) -> DictConfig:

View File

@ -1,4 +1,5 @@
{
"id": "6bfa0b3a-7090-4cd9-ad2d-a4b8662b6e71",
"name": "ESRGAN Upscaling with Canny ControlNet",
"author": "InvokeAI",
"description": "Sample workflow for using Upscaling with ControlNet with SD1.5",
@ -76,12 +77,12 @@
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 1250,
"y": 1500
},
"width": 320,
"height": 219
}
},
{
"id": "d8ace142-c05f-4f1d-8982-88dc7473958d",
@ -147,12 +148,12 @@
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 700,
"y": 1375
},
"width": 320,
"height": 193
}
},
{
"id": "771bdf6a-0813-4099-a5d8-921a138754d4",
@ -213,12 +214,12 @@
}
}
},
"width": 320,
"height": 225,
"position": {
"x": 375,
"y": 1900
},
"width": 320,
"height": 189
}
},
{
"id": "f7564dd2-9539-47f2-ac13-190804461f4e",
@ -314,12 +315,12 @@
}
}
},
"width": 320,
"height": 340,
"position": {
"x": 775,
"y": 1900
},
"width": 320,
"height": 295
}
},
{
"id": "1d887701-df21-4966-ae6e-a7d82307d7bd",
@ -415,12 +416,12 @@
}
}
},
"width": 320,
"height": 340,
"position": {
"x": 1200,
"y": 1900
},
"width": 320,
"height": 293
}
},
{
"id": "ca1d020c-89a8-4958-880a-016d28775cfa",
@ -433,7 +434,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.1.1",
"version": "1.1.0",
"nodePack": "invokeai",
"inputs": {
"image": {
@ -536,12 +537,12 @@
}
}
},
"width": 320,
"height": 511,
"position": {
"x": 1650,
"y": 1900
},
"width": 320,
"height": 451
}
},
{
"id": "f50624ce-82bf-41d0-bdf7-8aab11a80d48",
@ -639,12 +640,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 1650,
"y": 1775
},
"width": 320,
"height": 24
}
},
{
"id": "c3737554-8d87-48ff-a6f8-e71d2867f434",
@ -657,7 +658,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.1",
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
@ -865,12 +866,12 @@
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 2128.740065979906,
"y": 1232.6219060454753
},
"width": 320,
"height": 612
}
},
{
"id": "3ed9b2ef-f4ec-40a7-94db-92e63b583ec0",
@ -977,12 +978,12 @@
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 2559.4751127537957,
"y": 1246.6000376741406
},
"width": 320,
"height": 224
}
},
{
"id": "5ca498a4-c8c8-4580-a396-0c984317205d",
@ -1078,12 +1079,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 1650,
"y": 1675
},
"width": 320,
"height": 24
}
},
{
"id": "63b6ab7e-5b05-4d1b-a3b1-42d8e53ce16b",
@ -1136,12 +1137,12 @@
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 1250,
"y": 1200
},
"width": 320,
"height": 219
}
},
{
"id": "eb8f6f8a-c7b1-4914-806e-045ee2717a35",
@ -1194,168 +1195,168 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 1650,
"y": 1600
},
"width": 320,
"height": 24
}
}
],
"edges": [
{
"id": "5ca498a4-c8c8-4580-a396-0c984317205d-f50624ce-82bf-41d0-bdf7-8aab11a80d48-collapsed",
"type": "collapsed",
"source": "5ca498a4-c8c8-4580-a396-0c984317205d",
"target": "f50624ce-82bf-41d0-bdf7-8aab11a80d48"
"target": "f50624ce-82bf-41d0-bdf7-8aab11a80d48",
"type": "collapsed"
},
{
"id": "eb8f6f8a-c7b1-4914-806e-045ee2717a35-f50624ce-82bf-41d0-bdf7-8aab11a80d48-collapsed",
"type": "collapsed",
"source": "eb8f6f8a-c7b1-4914-806e-045ee2717a35",
"target": "f50624ce-82bf-41d0-bdf7-8aab11a80d48"
"target": "f50624ce-82bf-41d0-bdf7-8aab11a80d48",
"type": "collapsed"
},
{
"id": "reactflow__edge-771bdf6a-0813-4099-a5d8-921a138754d4image-f7564dd2-9539-47f2-ac13-190804461f4eimage",
"type": "default",
"source": "771bdf6a-0813-4099-a5d8-921a138754d4",
"target": "f7564dd2-9539-47f2-ac13-190804461f4e",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-f7564dd2-9539-47f2-ac13-190804461f4eimage-1d887701-df21-4966-ae6e-a7d82307d7bdimage",
"type": "default",
"source": "f7564dd2-9539-47f2-ac13-190804461f4e",
"target": "1d887701-df21-4966-ae6e-a7d82307d7bd",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-5ca498a4-c8c8-4580-a396-0c984317205dwidth-f50624ce-82bf-41d0-bdf7-8aab11a80d48width",
"type": "default",
"source": "5ca498a4-c8c8-4580-a396-0c984317205d",
"target": "f50624ce-82bf-41d0-bdf7-8aab11a80d48",
"type": "default",
"sourceHandle": "width",
"targetHandle": "width"
},
{
"id": "reactflow__edge-5ca498a4-c8c8-4580-a396-0c984317205dheight-f50624ce-82bf-41d0-bdf7-8aab11a80d48height",
"type": "default",
"source": "5ca498a4-c8c8-4580-a396-0c984317205d",
"target": "f50624ce-82bf-41d0-bdf7-8aab11a80d48",
"type": "default",
"sourceHandle": "height",
"targetHandle": "height"
},
{
"id": "reactflow__edge-f50624ce-82bf-41d0-bdf7-8aab11a80d48noise-c3737554-8d87-48ff-a6f8-e71d2867f434noise",
"type": "default",
"source": "f50624ce-82bf-41d0-bdf7-8aab11a80d48",
"target": "c3737554-8d87-48ff-a6f8-e71d2867f434",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-5ca498a4-c8c8-4580-a396-0c984317205dlatents-c3737554-8d87-48ff-a6f8-e71d2867f434latents",
"type": "default",
"source": "5ca498a4-c8c8-4580-a396-0c984317205d",
"target": "c3737554-8d87-48ff-a6f8-e71d2867f434",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-e8bf67fe-67de-4227-87eb-79e86afdfc74conditioning-c3737554-8d87-48ff-a6f8-e71d2867f434negative_conditioning",
"type": "default",
"source": "e8bf67fe-67de-4227-87eb-79e86afdfc74",
"target": "c3737554-8d87-48ff-a6f8-e71d2867f434",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-63b6ab7e-5b05-4d1b-a3b1-42d8e53ce16bconditioning-c3737554-8d87-48ff-a6f8-e71d2867f434positive_conditioning",
"type": "default",
"source": "63b6ab7e-5b05-4d1b-a3b1-42d8e53ce16b",
"target": "c3737554-8d87-48ff-a6f8-e71d2867f434",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-d8ace142-c05f-4f1d-8982-88dc7473958dclip-63b6ab7e-5b05-4d1b-a3b1-42d8e53ce16bclip",
"type": "default",
"source": "d8ace142-c05f-4f1d-8982-88dc7473958d",
"target": "63b6ab7e-5b05-4d1b-a3b1-42d8e53ce16b",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-d8ace142-c05f-4f1d-8982-88dc7473958dclip-e8bf67fe-67de-4227-87eb-79e86afdfc74clip",
"type": "default",
"source": "d8ace142-c05f-4f1d-8982-88dc7473958d",
"target": "e8bf67fe-67de-4227-87eb-79e86afdfc74",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-1d887701-df21-4966-ae6e-a7d82307d7bdimage-ca1d020c-89a8-4958-880a-016d28775cfaimage",
"type": "default",
"source": "1d887701-df21-4966-ae6e-a7d82307d7bd",
"target": "ca1d020c-89a8-4958-880a-016d28775cfa",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-ca1d020c-89a8-4958-880a-016d28775cfacontrol-c3737554-8d87-48ff-a6f8-e71d2867f434control",
"type": "default",
"source": "ca1d020c-89a8-4958-880a-016d28775cfa",
"target": "c3737554-8d87-48ff-a6f8-e71d2867f434",
"type": "default",
"sourceHandle": "control",
"targetHandle": "control"
},
{
"id": "reactflow__edge-c3737554-8d87-48ff-a6f8-e71d2867f434latents-3ed9b2ef-f4ec-40a7-94db-92e63b583ec0latents",
"type": "default",
"source": "c3737554-8d87-48ff-a6f8-e71d2867f434",
"target": "3ed9b2ef-f4ec-40a7-94db-92e63b583ec0",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-d8ace142-c05f-4f1d-8982-88dc7473958dvae-3ed9b2ef-f4ec-40a7-94db-92e63b583ec0vae",
"type": "default",
"source": "d8ace142-c05f-4f1d-8982-88dc7473958d",
"target": "3ed9b2ef-f4ec-40a7-94db-92e63b583ec0",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f7564dd2-9539-47f2-ac13-190804461f4eimage-5ca498a4-c8c8-4580-a396-0c984317205dimage",
"type": "default",
"source": "f7564dd2-9539-47f2-ac13-190804461f4e",
"target": "5ca498a4-c8c8-4580-a396-0c984317205d",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-d8ace142-c05f-4f1d-8982-88dc7473958dunet-c3737554-8d87-48ff-a6f8-e71d2867f434unet",
"type": "default",
"source": "d8ace142-c05f-4f1d-8982-88dc7473958d",
"target": "c3737554-8d87-48ff-a6f8-e71d2867f434",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-d8ace142-c05f-4f1d-8982-88dc7473958dvae-5ca498a4-c8c8-4580-a396-0c984317205dvae",
"type": "default",
"source": "d8ace142-c05f-4f1d-8982-88dc7473958d",
"target": "5ca498a4-c8c8-4580-a396-0c984317205d",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-eb8f6f8a-c7b1-4914-806e-045ee2717a35value-f50624ce-82bf-41d0-bdf7-8aab11a80d48seed",
"type": "default",
"source": "eb8f6f8a-c7b1-4914-806e-045ee2717a35",
"target": "f50624ce-82bf-41d0-bdf7-8aab11a80d48",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
}

View File

@ -1,4 +1,5 @@
{
"id": "1e385b84-86f8-452e-9697-9e5abed20518",
"name": "Multi ControlNet (Canny & Depth)",
"author": "InvokeAI",
"description": "A sample workflow using canny & depth ControlNets to guide the generation process. ",
@ -92,12 +93,12 @@
}
}
},
"width": 320,
"height": 225,
"position": {
"x": 3625,
"y": -75
},
"width": 320,
"height": 189
}
},
{
"id": "a33199c2-8340-401e-b8a2-42ffa875fc1c",
@ -110,7 +111,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.1.1",
"version": "1.1.0",
"nodePack": "invokeai",
"inputs": {
"image": {
@ -213,12 +214,12 @@
}
}
},
"width": 320,
"height": 511,
"position": {
"x": 4477.604342844504,
"y": -49.39005411272677
},
"width": 320,
"height": 451
}
},
{
"id": "273e3f96-49ea-4dc5-9d5b-9660390f14e1",
@ -271,12 +272,12 @@
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 4075,
"y": -825
},
"width": 320,
"height": 219
}
},
{
"id": "54486974-835b-4d81-8f82-05f9f32ce9e9",
@ -342,12 +343,12 @@
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 3600,
"y": -1000
},
"width": 320,
"height": 193
}
},
{
"id": "7ce68934-3419-42d4-ac70-82cfc9397306",
@ -400,12 +401,12 @@
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 4075,
"y": -1125
},
"width": 320,
"height": 219
}
},
{
"id": "d204d184-f209-4fae-a0a1-d152800844e1",
@ -418,7 +419,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.1.1",
"version": "1.1.0",
"nodePack": "invokeai",
"inputs": {
"image": {
@ -521,12 +522,12 @@
}
}
},
"width": 320,
"height": 511,
"position": {
"x": 4479.68542130465,
"y": -618.4221638099414
},
"width": 320,
"height": 451
}
},
{
"id": "c4b23e64-7986-40c4-9cad-46327b12e204",
@ -587,12 +588,12 @@
}
}
},
"width": 320,
"height": 225,
"position": {
"x": 3625,
"y": -425
},
"width": 320,
"height": 189
}
},
{
"id": "ca4d5059-8bfb-447f-b415-da0faba5a143",
@ -632,12 +633,12 @@
}
}
},
"width": 320,
"height": 104,
"position": {
"x": 4875,
"y": -575
},
"width": 320,
"height": 87
}
},
{
"id": "018b1214-c2af-43a7-9910-fb687c6726d7",
@ -733,12 +734,12 @@
}
}
},
"width": 320,
"height": 340,
"position": {
"x": 4100,
"y": -75
},
"width": 320,
"height": 293
}
},
{
"id": "c826ba5e-9676-4475-b260-07b85e88753c",
@ -834,12 +835,12 @@
}
}
},
"width": 320,
"height": 340,
"position": {
"x": 4095.757337055795,
"y": -455.63440891935863
},
"width": 320,
"height": 293
}
},
{
"id": "9db25398-c869-4a63-8815-c6559341ef12",
@ -946,12 +947,12 @@
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 5675,
"y": -825
},
"width": 320,
"height": 224
}
},
{
"id": "ac481b7f-08bf-4a9d-9e0c-3a82ea5243ce",
@ -964,7 +965,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.1",
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
@ -1172,12 +1173,12 @@
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 5274.672987098195,
"y": -823.0752416664332
},
"width": 320,
"height": 612
}
},
{
"id": "2e77a0a1-db6a-47a2-a8bf-1e003be6423b",
@ -1274,12 +1275,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 4875,
"y": -675
},
"width": 320,
"height": 24
}
},
{
"id": "8b260b4d-3fd6-44d4-b1be-9f0e43c628ce",
@ -1332,146 +1333,146 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 4875,
"y": -750
},
"width": 320,
"height": 24
}
}
],
"edges": [
{
"id": "8b260b4d-3fd6-44d4-b1be-9f0e43c628ce-2e77a0a1-db6a-47a2-a8bf-1e003be6423b-collapsed",
"type": "collapsed",
"source": "8b260b4d-3fd6-44d4-b1be-9f0e43c628ce",
"target": "2e77a0a1-db6a-47a2-a8bf-1e003be6423b"
"target": "2e77a0a1-db6a-47a2-a8bf-1e003be6423b",
"type": "collapsed"
},
{
"id": "reactflow__edge-54486974-835b-4d81-8f82-05f9f32ce9e9clip-7ce68934-3419-42d4-ac70-82cfc9397306clip",
"type": "default",
"source": "54486974-835b-4d81-8f82-05f9f32ce9e9",
"target": "7ce68934-3419-42d4-ac70-82cfc9397306",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-54486974-835b-4d81-8f82-05f9f32ce9e9clip-273e3f96-49ea-4dc5-9d5b-9660390f14e1clip",
"type": "default",
"source": "54486974-835b-4d81-8f82-05f9f32ce9e9",
"target": "273e3f96-49ea-4dc5-9d5b-9660390f14e1",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-a33199c2-8340-401e-b8a2-42ffa875fc1ccontrol-ca4d5059-8bfb-447f-b415-da0faba5a143item",
"type": "default",
"source": "a33199c2-8340-401e-b8a2-42ffa875fc1c",
"target": "ca4d5059-8bfb-447f-b415-da0faba5a143",
"type": "default",
"sourceHandle": "control",
"targetHandle": "item"
},
{
"id": "reactflow__edge-d204d184-f209-4fae-a0a1-d152800844e1control-ca4d5059-8bfb-447f-b415-da0faba5a143item",
"type": "default",
"source": "d204d184-f209-4fae-a0a1-d152800844e1",
"target": "ca4d5059-8bfb-447f-b415-da0faba5a143",
"type": "default",
"sourceHandle": "control",
"targetHandle": "item"
},
{
"id": "reactflow__edge-8e860e51-5045-456e-bf04-9a62a2a5c49eimage-018b1214-c2af-43a7-9910-fb687c6726d7image",
"type": "default",
"source": "8e860e51-5045-456e-bf04-9a62a2a5c49e",
"target": "018b1214-c2af-43a7-9910-fb687c6726d7",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-018b1214-c2af-43a7-9910-fb687c6726d7image-a33199c2-8340-401e-b8a2-42ffa875fc1cimage",
"type": "default",
"source": "018b1214-c2af-43a7-9910-fb687c6726d7",
"target": "a33199c2-8340-401e-b8a2-42ffa875fc1c",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-c4b23e64-7986-40c4-9cad-46327b12e204image-c826ba5e-9676-4475-b260-07b85e88753cimage",
"type": "default",
"source": "c4b23e64-7986-40c4-9cad-46327b12e204",
"target": "c826ba5e-9676-4475-b260-07b85e88753c",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-c826ba5e-9676-4475-b260-07b85e88753cimage-d204d184-f209-4fae-a0a1-d152800844e1image",
"type": "default",
"source": "c826ba5e-9676-4475-b260-07b85e88753c",
"target": "d204d184-f209-4fae-a0a1-d152800844e1",
"type": "default",
"sourceHandle": "image",
"targetHandle": "image"
},
{
"id": "reactflow__edge-54486974-835b-4d81-8f82-05f9f32ce9e9vae-9db25398-c869-4a63-8815-c6559341ef12vae",
"type": "default",
"source": "54486974-835b-4d81-8f82-05f9f32ce9e9",
"target": "9db25398-c869-4a63-8815-c6559341ef12",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-ac481b7f-08bf-4a9d-9e0c-3a82ea5243celatents-9db25398-c869-4a63-8815-c6559341ef12latents",
"type": "default",
"source": "ac481b7f-08bf-4a9d-9e0c-3a82ea5243ce",
"target": "9db25398-c869-4a63-8815-c6559341ef12",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-ca4d5059-8bfb-447f-b415-da0faba5a143collection-ac481b7f-08bf-4a9d-9e0c-3a82ea5243cecontrol",
"type": "default",
"source": "ca4d5059-8bfb-447f-b415-da0faba5a143",
"target": "ac481b7f-08bf-4a9d-9e0c-3a82ea5243ce",
"type": "default",
"sourceHandle": "collection",
"targetHandle": "control"
},
{
"id": "reactflow__edge-54486974-835b-4d81-8f82-05f9f32ce9e9unet-ac481b7f-08bf-4a9d-9e0c-3a82ea5243ceunet",
"type": "default",
"source": "54486974-835b-4d81-8f82-05f9f32ce9e9",
"target": "ac481b7f-08bf-4a9d-9e0c-3a82ea5243ce",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-273e3f96-49ea-4dc5-9d5b-9660390f14e1conditioning-ac481b7f-08bf-4a9d-9e0c-3a82ea5243cenegative_conditioning",
"type": "default",
"source": "273e3f96-49ea-4dc5-9d5b-9660390f14e1",
"target": "ac481b7f-08bf-4a9d-9e0c-3a82ea5243ce",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-7ce68934-3419-42d4-ac70-82cfc9397306conditioning-ac481b7f-08bf-4a9d-9e0c-3a82ea5243cepositive_conditioning",
"type": "default",
"source": "7ce68934-3419-42d4-ac70-82cfc9397306",
"target": "ac481b7f-08bf-4a9d-9e0c-3a82ea5243ce",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-2e77a0a1-db6a-47a2-a8bf-1e003be6423bnoise-ac481b7f-08bf-4a9d-9e0c-3a82ea5243cenoise",
"type": "default",
"source": "2e77a0a1-db6a-47a2-a8bf-1e003be6423b",
"target": "ac481b7f-08bf-4a9d-9e0c-3a82ea5243ce",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-8b260b4d-3fd6-44d4-b1be-9f0e43c628cevalue-2e77a0a1-db6a-47a2-a8bf-1e003be6423bseed",
"type": "default",
"source": "8b260b4d-3fd6-44d4-b1be-9f0e43c628ce",
"target": "2e77a0a1-db6a-47a2-a8bf-1e003be6423b",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
}

View File

@ -20,6 +20,7 @@
"category": "default",
"version": "2.0.0"
},
"id": "d1609af5-eb0a-4f73-b573-c9af96a8d6bf",
"nodes": [
{
"id": "c2eaf1ba-5708-4679-9e15-945b8b432692",
@ -72,12 +73,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -200
},
"width": 320,
"height": 24
}
},
{
"id": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
@ -167,12 +168,12 @@
}
}
},
"width": 320,
"height": 580,
"position": {
"x": 475,
"y": -400
},
"width": 320,
"height": 506
}
},
{
"id": "1b89067c-3f6b-42c8-991f-e3055789b251",
@ -232,12 +233,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -400
},
"width": 320,
"height": 24
}
},
{
"id": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
@ -303,12 +304,12 @@
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 0,
"y": -375
},
"width": 320,
"height": 193
}
},
{
"id": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
@ -361,12 +362,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -275
},
"width": 320,
"height": 24
}
},
{
"id": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
@ -464,12 +465,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": 25
},
"width": 320,
"height": 24
}
},
{
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
@ -523,12 +524,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -50
},
"width": 320,
"height": 24
}
},
{
"id": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
@ -635,12 +636,12 @@
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 2037.861329274915,
"y": -329.8393457509562
},
"width": 320,
"height": 224
}
},
{
"id": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
@ -653,7 +654,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.1",
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
@ -861,112 +862,112 @@
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 1570.9941088179146,
"y": -407.6505491604564
},
"width": 320,
"height": 612
}
}
],
"edges": [
{
"id": "1b89067c-3f6b-42c8-991f-e3055789b251-fc9d0e35-a6de-4a19-84e1-c72497c823f6-collapsed",
"type": "collapsed",
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6"
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "collapsed"
},
{
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77-collapsed",
"type": "collapsed",
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77"
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "collapsed"
},
{
"id": "reactflow__edge-1b7e0df8-8589-4915-a4ea-c0088f15d642collection-1b89067c-3f6b-42c8-991f-e3055789b251collection",
"type": "default",
"source": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"target": "1b89067c-3f6b-42c8-991f-e3055789b251",
"type": "default",
"sourceHandle": "collection",
"targetHandle": "collection"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-fc9d0e35-a6de-4a19-84e1-c72497c823f6clip",
"type": "default",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-1b89067c-3f6b-42c8-991f-e3055789b251item-fc9d0e35-a6de-4a19-84e1-c72497c823f6prompt",
"type": "default",
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "default",
"sourceHandle": "item",
"targetHandle": "prompt"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-c2eaf1ba-5708-4679-9e15-945b8b432692clip",
"type": "default",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5value-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77seed",
"type": "default",
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-fc9d0e35-a6de-4a19-84e1-c72497c823f6conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5epositive_conditioning",
"type": "default",
"source": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-c2eaf1ba-5708-4679-9e15-945b8b432692conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enegative_conditioning",
"type": "default",
"source": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77noise-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enoise",
"type": "default",
"source": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426unet-2fb1577f-0a56-4f12-8711-8afcaaaf1d5eunet",
"type": "default",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-2fb1577f-0a56-4f12-8711-8afcaaaf1d5elatents-491ec988-3c77-4c37-af8a-39a0c4e7a2a1latents",
"type": "default",
"source": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426vae-491ec988-3c77-4c37-af8a-39a0c4e7a2a1vae",
"type": "default",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
}

View File

@ -25,9 +25,10 @@
}
],
"meta": {
"category": "default",
"version": "2.0.0"
"version": "2.0.0",
"category": "default"
},
"id": "a9d70c39-4cdd-4176-9942-8ff3fe32d3b1",
"nodes": [
{
"id": "85b77bb2-c67a-416a-b3e8-291abe746c44",
@ -79,12 +80,12 @@
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 3425,
"y": -300
},
"width": 320,
"height": 219
}
},
{
"id": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
@ -149,12 +150,12 @@
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 2500,
"y": -600
},
"width": 320,
"height": 193
}
},
{
"id": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
@ -242,12 +243,12 @@
}
}
},
"width": 320,
"height": 252,
"position": {
"x": 2975,
"y": -600
},
"width": 320,
"height": 218
}
},
{
"id": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
@ -299,12 +300,12 @@
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 3425,
"y": -575
},
"width": 320,
"height": 219
}
},
{
"id": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
@ -317,7 +318,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.1",
"version": "1.5.0",
"inputs": {
"positive_conditioning": {
"id": "025ff44b-c4c6-4339-91b4-5f461e2cadc5",
@ -524,12 +525,12 @@
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 3975,
"y": -575
},
"width": 320,
"height": 612
}
},
{
"id": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
@ -626,12 +627,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 3425,
"y": 75
},
"width": 320,
"height": 24
}
},
{
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
@ -684,12 +685,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 3425,
"y": 0
},
"width": 320,
"height": 24
}
},
{
"id": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
@ -795,106 +796,106 @@
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 4450,
"y": -550
},
"width": 320,
"height": 224
}
}
],
"edges": [
{
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953-ea18915f-2c5b-4569-b725-8e9e9122e8d3-collapsed",
"type": "collapsed",
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3"
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "collapsed"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818clip-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip",
"type": "default",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-c3fa6872-2599-4a82-a596-b3446a66cf8bclip",
"type": "default",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818unet-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet",
"type": "default",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63unet",
"type": "default",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-85b77bb2-c67a-416a-b3e8-291abe746c44conditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63negative_conditioning",
"type": "default",
"source": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-c3fa6872-2599-4a82-a596-b3446a66cf8bconditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63positive_conditioning",
"type": "default",
"source": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-ea18915f-2c5b-4569-b725-8e9e9122e8d3noise-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63noise",
"type": "default",
"source": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-6fd74a17-6065-47a5-b48b-f4e2b8fa7953value-ea18915f-2c5b-4569-b725-8e9e9122e8d3seed",
"type": "default",
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63latents-a9683c0a-6b1f-4a5e-8187-c57e764b3400latents",
"type": "default",
"source": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818vae-a9683c0a-6b1f-4a5e-8187-c57e764b3400vae",
"type": "default",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-85b77bb2-c67a-416a-b3e8-291abe746c44clip",
"type": "default",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
}

View File

@ -84,12 +84,12 @@
}
}
},
"width": 320,
"height": 259,
"position": {
"x": 1000,
"y": 350
},
"width": 320,
"height": 219
}
},
{
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
@ -187,12 +187,12 @@
}
}
},
"width": 320,
"height": 388,
"position": {
"x": 600,
"y": 325
},
"width": 320,
"height": 388
}
},
{
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
@ -258,12 +258,12 @@
}
}
},
"width": 320,
"height": 226,
"position": {
"x": 600,
"y": 25
},
"width": 320,
"height": 193
}
},
{
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
@ -316,12 +316,12 @@
}
}
},
"width": 320,
"height": 259,
"position": {
"x": 1000,
"y": 25
},
"width": 320,
"height": 219
}
},
{
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
@ -375,12 +375,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 600,
"y": 275
},
"width": 320,
"height": 32
}
},
{
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
@ -393,7 +393,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.1",
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
@ -601,12 +601,12 @@
}
}
},
"width": 320,
"height": 703,
"position": {
"x": 1400,
"y": 25
},
"width": 320,
"height": 612
}
},
{
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
@ -713,86 +713,86 @@
}
}
},
"width": 320,
"height": 266,
"position": {
"x": 1800,
"y": 25
},
"width": 320,
"height": 224
}
}
],
"edges": [
{
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
"type": "default",
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-7d8bf987-284f-413a-b2fd-d825445a5d6cclip",
"type": "default",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-93dc02a4-d05b-48ed-b99c-c9b616af3402clip",
"type": "default",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-eea2702a-19fb-45b5-9d75-56b4211ec03cnoise",
"type": "default",
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-7d8bf987-284f-413a-b2fd-d825445a5d6cconditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cpositive_conditioning",
"type": "default",
"source": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-93dc02a4-d05b-48ed-b99c-c9b616af3402conditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cnegative_conditioning",
"type": "default",
"source": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8unet-eea2702a-19fb-45b5-9d75-56b4211ec03cunet",
"type": "default",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
"type": "default",
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8vae-58c957f5-0d01-41fc-a803-b2bbf0413d4fvae",
"type": "default",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
}
]
}
}

View File

@ -80,12 +80,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 750,
"y": -225
},
"width": 320,
"height": 24
}
},
{
"id": "719dabe8-8297-4749-aea1-37be301cd425",
@ -126,12 +126,12 @@
}
}
},
"width": 320,
"height": 258,
"position": {
"x": 750,
"y": -125
},
"width": 320,
"height": 219
}
},
{
"id": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204",
@ -279,12 +279,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 750,
"y": 200
},
"width": 320,
"height": 24
}
},
{
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
@ -382,12 +382,12 @@
}
}
},
"width": 320,
"height": 388,
"position": {
"x": 375,
"y": 0
},
"width": 320,
"height": 336
}
},
{
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
@ -441,12 +441,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 375,
"y": -50
},
"width": 320,
"height": 24
}
},
{
"id": "30d3289c-773c-4152-a9d2-bd8a99c8fd22",
@ -471,7 +471,8 @@
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SDXLMainModelField"
}
},
"value": null
}
},
"outputs": {
@ -517,12 +518,12 @@
}
}
},
"width": 320,
"height": 257,
"position": {
"x": 375,
"y": -500
},
"width": 320,
"height": 219
}
},
{
"id": "faf965a4-7530-427b-b1f3-4ba6505c2a08",
@ -670,12 +671,12 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 750,
"y": -175
},
"width": 320,
"height": 24
}
},
{
"id": "63e91020-83b2-4f35-b174-ad9692aabb48",
@ -782,12 +783,12 @@
}
}
},
"width": 320,
"height": 266,
"position": {
"x": 1475,
"y": -500
},
"width": 320,
"height": 224
}
},
{
"id": "50a36525-3c0a-4cc5-977c-e4bfc3fd6dfb",
@ -800,7 +801,7 @@
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.1",
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
@ -1008,12 +1009,12 @@
}
}
},
"width": 320,
"height": 702,
"position": {
"x": 1125,
"y": -500
},
"width": 320,
"height": 612
}
},
{
"id": "0093692f-9cf4-454d-a5b8-62f0e3eb3bb8",
@ -1037,7 +1038,8 @@
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VAEModelField"
}
},
"value": null
}
},
"outputs": {
@ -1053,12 +1055,12 @@
}
}
},
"width": 320,
"height": 161,
"position": {
"x": 375,
"y": -225
},
"width": 320,
"height": 139
}
},
{
"id": "ade2c0d3-0384-4157-b39b-29ce429cfa15",
@ -1099,12 +1101,12 @@
}
}
},
"width": 320,
"height": 258,
"position": {
"x": 750,
"y": -500
},
"width": 320,
"height": 219
}
},
{
"id": "ad8fa655-3a76-43d0-9c02-4d7644dea650",
@ -1157,162 +1159,162 @@
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 750,
"y": 150
},
"width": 320,
"height": 24
}
}
],
"edges": [
{
"id": "3774ec24-a69e-4254-864c-097d07a6256f-faf965a4-7530-427b-b1f3-4ba6505c2a08-collapsed",
"type": "collapsed",
"source": "3774ec24-a69e-4254-864c-097d07a6256f",
"target": "faf965a4-7530-427b-b1f3-4ba6505c2a08"
"target": "faf965a4-7530-427b-b1f3-4ba6505c2a08",
"type": "collapsed"
},
{
"id": "ad8fa655-3a76-43d0-9c02-4d7644dea650-3193ad09-a7c2-4bf4-a3a9-1c61cc33a204-collapsed",
"type": "collapsed",
"source": "ad8fa655-3a76-43d0-9c02-4d7644dea650",
"target": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204"
"target": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204",
"type": "collapsed"
},
{
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
"type": "default",
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-30d3289c-773c-4152-a9d2-bd8a99c8fd22clip-faf965a4-7530-427b-b1f3-4ba6505c2a08clip",
"type": "default",
"source": "30d3289c-773c-4152-a9d2-bd8a99c8fd22",
"target": "faf965a4-7530-427b-b1f3-4ba6505c2a08",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-30d3289c-773c-4152-a9d2-bd8a99c8fd22clip2-faf965a4-7530-427b-b1f3-4ba6505c2a08clip2",
"type": "default",
"source": "30d3289c-773c-4152-a9d2-bd8a99c8fd22",
"target": "faf965a4-7530-427b-b1f3-4ba6505c2a08",
"type": "default",
"sourceHandle": "clip2",
"targetHandle": "clip2"
},
{
"id": "reactflow__edge-30d3289c-773c-4152-a9d2-bd8a99c8fd22clip-3193ad09-a7c2-4bf4-a3a9-1c61cc33a204clip",
"type": "default",
"source": "30d3289c-773c-4152-a9d2-bd8a99c8fd22",
"target": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-30d3289c-773c-4152-a9d2-bd8a99c8fd22clip2-3193ad09-a7c2-4bf4-a3a9-1c61cc33a204clip2",
"type": "default",
"source": "30d3289c-773c-4152-a9d2-bd8a99c8fd22",
"target": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204",
"type": "default",
"sourceHandle": "clip2",
"targetHandle": "clip2"
},
{
"id": "reactflow__edge-30d3289c-773c-4152-a9d2-bd8a99c8fd22unet-50a36525-3c0a-4cc5-977c-e4bfc3fd6dfbunet",
"type": "default",
"source": "30d3289c-773c-4152-a9d2-bd8a99c8fd22",
"target": "50a36525-3c0a-4cc5-977c-e4bfc3fd6dfb",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-faf965a4-7530-427b-b1f3-4ba6505c2a08conditioning-50a36525-3c0a-4cc5-977c-e4bfc3fd6dfbpositive_conditioning",
"type": "default",
"source": "faf965a4-7530-427b-b1f3-4ba6505c2a08",
"target": "50a36525-3c0a-4cc5-977c-e4bfc3fd6dfb",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-3193ad09-a7c2-4bf4-a3a9-1c61cc33a204conditioning-50a36525-3c0a-4cc5-977c-e4bfc3fd6dfbnegative_conditioning",
"type": "default",
"source": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204",
"target": "50a36525-3c0a-4cc5-977c-e4bfc3fd6dfb",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-50a36525-3c0a-4cc5-977c-e4bfc3fd6dfbnoise",
"type": "default",
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
"target": "50a36525-3c0a-4cc5-977c-e4bfc3fd6dfb",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-50a36525-3c0a-4cc5-977c-e4bfc3fd6dfblatents-63e91020-83b2-4f35-b174-ad9692aabb48latents",
"type": "default",
"source": "50a36525-3c0a-4cc5-977c-e4bfc3fd6dfb",
"target": "63e91020-83b2-4f35-b174-ad9692aabb48",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-0093692f-9cf4-454d-a5b8-62f0e3eb3bb8vae-63e91020-83b2-4f35-b174-ad9692aabb48vae",
"type": "default",
"source": "0093692f-9cf4-454d-a5b8-62f0e3eb3bb8",
"target": "63e91020-83b2-4f35-b174-ad9692aabb48",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-ade2c0d3-0384-4157-b39b-29ce429cfa15value-faf965a4-7530-427b-b1f3-4ba6505c2a08prompt",
"type": "default",
"source": "ade2c0d3-0384-4157-b39b-29ce429cfa15",
"target": "faf965a4-7530-427b-b1f3-4ba6505c2a08",
"type": "default",
"sourceHandle": "value",
"targetHandle": "prompt"
},
{
"id": "reactflow__edge-719dabe8-8297-4749-aea1-37be301cd425value-3193ad09-a7c2-4bf4-a3a9-1c61cc33a204prompt",
"type": "default",
"source": "719dabe8-8297-4749-aea1-37be301cd425",
"target": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204",
"type": "default",
"sourceHandle": "value",
"targetHandle": "prompt"
},
{
"id": "reactflow__edge-719dabe8-8297-4749-aea1-37be301cd425value-ad8fa655-3a76-43d0-9c02-4d7644dea650string_left",
"type": "default",
"source": "719dabe8-8297-4749-aea1-37be301cd425",
"target": "ad8fa655-3a76-43d0-9c02-4d7644dea650",
"type": "default",
"sourceHandle": "value",
"targetHandle": "string_left"
},
{
"id": "reactflow__edge-ad8fa655-3a76-43d0-9c02-4d7644dea650value-3193ad09-a7c2-4bf4-a3a9-1c61cc33a204style",
"type": "default",
"source": "ad8fa655-3a76-43d0-9c02-4d7644dea650",
"target": "3193ad09-a7c2-4bf4-a3a9-1c61cc33a204",
"type": "default",
"sourceHandle": "value",
"targetHandle": "style"
},
{
"id": "reactflow__edge-ade2c0d3-0384-4157-b39b-29ce429cfa15value-3774ec24-a69e-4254-864c-097d07a6256fstring_left",
"type": "default",
"source": "ade2c0d3-0384-4157-b39b-29ce429cfa15",
"target": "3774ec24-a69e-4254-864c-097d07a6256f",
"type": "default",
"sourceHandle": "value",
"targetHandle": "string_left"
},
{
"id": "reactflow__edge-3774ec24-a69e-4254-864c-097d07a6256fvalue-faf965a4-7530-427b-b1f3-4ba6505c2a08style",
"type": "default",
"source": "3774ec24-a69e-4254-864c-097d07a6256f",
"target": "faf965a4-7530-427b-b1f3-4ba6505c2a08",
"type": "default",
"sourceHandle": "value",
"targetHandle": "style"
}
]
}
}

View File

@ -31,7 +31,6 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum):
class WorkflowCategory(str, Enum, metaclass=MetaEnum):
User = "user"
Default = "default"
Project = "project"
class WorkflowMeta(BaseModel):

View File

@ -1,109 +0,0 @@
import pathlib
from typing import Literal, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.util import download_with_progress_bar
config = InvokeAIAppConfig.get_config()
DEPTH_ANYTHING_MODELS = {
"large": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
},
"base": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
},
"small": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
},
}
transform = Compose(
[
Resize(
width=518,
height=518,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
]
)
class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
def load_model(self, model_size=Literal["large", "base", "small"]):
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
if not DEPTH_ANYTHING_MODEL_PATH.exists():
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
if not self.model or model_size != self.model_size:
del self.model
self.model_size = model_size
match self.model_size:
case "small":
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
case _:
raise TypeError("Not a supported model")
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(choose_torch_device())
return self.model
def to(self, device):
self.model.to(device)
return self
def __call__(self, image, resolution=512, offload=False):
image = np.array(image, dtype=np.uint8)
image = image[:, :, ::-1] / 255.0
image_height, image_width = image.shape[:2]
image = transform({"image": image})["image"]
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
with torch.no_grad():
depth = self.model(image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
depth_map = Image.fromarray(depth_map)
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
if offload:
del self.model
return depth_map

View File

@ -1,145 +0,0 @@
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
if self.bn:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output

View File

@ -1,183 +0,0 @@
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from .blocks import FeatureFusionBlock, _make_scratch
torchhub_path = Path(__file__).parent.parent / "torchhub"
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class DPTHead(nn.Module):
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
super(DPTHead, self).__init__()
self.nclass = nclass
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList(
[
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
)
for out_channel in out_channels
]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
if nclass > 1:
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
)
else:
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DPT_DINOv2(nn.Module):
def __init__(
self,
features,
out_channels,
encoder="vitl",
use_bn=False,
use_clstoken=False,
):
super(DPT_DINOv2, self).__init__()
assert encoder in ["vits", "vitb", "vitl"]
# # in case the Internet connection is not stable, please load the DINOv2 locally
# if use_local:
# self.pretrained = torch.hub.load(
# torchhub_path / "facebookresearch_dinov2_main",
# "dinov2_{:}14".format(encoder),
# source="local",
# pretrained=False,
# )
# else:
# self.pretrained = torch.hub.load(
# "facebookresearch/dinov2",
# "dinov2_{:}14".format(encoder),
# )
self.pretrained = torch.hub.load(
"facebookresearch/dinov2",
"dinov2_{:}14".format(encoder),
)
dim = self.pretrained.blocks[0].attn.qkv.in_features
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
def forward(self, x):
h, w = x.shape[-2:]
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
patch_h, patch_w = h // 14, w // 14
depth = self.depth_head(features, patch_h, patch_w)
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
depth = F.relu(depth)
return depth.squeeze(1)

View File

@ -1,227 +0,0 @@
import math
import cv2
import numpy as np
import torch
import torch.nn.functional as F
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height)."""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "semseg_mask" in sample:
# sample["semseg_mask"] = cv2.resize(
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
# )
sample["semseg_mask"] = F.interpolate(
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
).numpy()[0, 0]
if "mask" in sample:
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
# sample["mask"] = sample["mask"].astype(bool)
# print(sample['image'].shape, sample['depth'].shape)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std."""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input."""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "semseg_mask" in sample:
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
return sample

View File

@ -283,16 +283,11 @@ class ModelInstall(object):
def _remove_installed(self, model_list: List[str]):
all_models = self.all_models()
models_to_remove = []
for path in model_list:
key = self.reverse_paths.get(path)
if key and all_models[key].installed:
models_to_remove.append(path)
for path in models_to_remove:
logger.warning(f"{path} already installed. Skipping")
model_list.remove(path)
logger.warning(f"{path} already installed. Skipping.")
model_list.remove(path)
def _add_required_models(self, model_list: List[str]):
additional_models = []

View File

@ -759,7 +759,7 @@ class ModelManager(object):
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
) -> None:
):
"""
Rename or rebase a model.
"""
@ -781,9 +781,6 @@ class ModelManager(object):
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
# keep the suffix!
if old_path.is_file():
new_name = Path(new_name).with_suffix(old_path.suffix).as_posix()
new_path = self.resolve_model_path(
Path(
BaseModelType(new_base).value,

View File

@ -6,7 +6,6 @@ from .config import (
InvalidModelConfigException,
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
ModelType,
ModelVariantType,
SchedulerPredictionType,
@ -16,16 +15,15 @@ from .probe import ModelProbe
from .search import ModelSearch
__all__ = [
"AnyModelConfig",
"BaseModelType",
"ModelRepoVariant",
"InvalidModelConfigException",
"ModelConfigFactory",
"ModelFormat",
"ModelProbe",
"ModelSearch",
"InvalidModelConfigException",
"ModelConfigFactory",
"BaseModelType",
"ModelType",
"ModelVariantType",
"SchedulerPredictionType",
"SubModelType",
"ModelVariantType",
"ModelFormat",
"SchedulerPredictionType",
"AnyModelConfig",
]

View File

@ -99,17 +99,6 @@ class SchedulerPredictionType(str, Enum):
Sample = "sample"
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
DEFAULT = "default" # model files without "fp16" or other qualifier
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
OPENVINO = "openvino"
FLAX = "flax"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""

View File

@ -1,50 +0,0 @@
"""
Initialization file for invokeai.backend.model_manager.metadata
Usage:
from invokeai.backend.model_manager.metadata import(
AnyModelRepoMetadata,
CommercialUsage,
LicenseRestrictions,
HuggingFaceMetadata,
CivitaiMetadata,
)
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
BaseMetadata,
CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata,
LicenseRestrictions,
ModelMetadataWithFiles,
RemoteModelFile,
UnknownMetadataException,
)
from .metadata_store import ModelMetadataStore
__all__ = [
"AnyModelRepoMetadata",
"AnyModelRepoMetadataValidator",
"CivitaiMetadata",
"CivitaiMetadataFetch",
"CommercialUsage",
"HuggingFaceMetadata",
"HuggingFaceMetadataFetch",
"LicenseRestrictions",
"ModelMetadataStore",
"BaseMetadata",
"ModelMetadataWithFiles",
"RemoteModelFile",
"UnknownMetadataException",
]

View File

@ -1,21 +0,0 @@
"""
Initialization file for invokeai.backend.model_manager.metadata.fetch
Usage:
from invokeai.backend.model_manager.metadata.fetch import (
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
)
from invokeai.backend.model_manager.metadata import CivitaiMetadata
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .civitai import CivitaiMetadataFetch
from .fetch_base import ModelMetadataFetchBase
from .huggingface import HuggingFaceMetadataFetch
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]

View File

@ -1,187 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the Civitai model repository.
In addition to the `from_url()` and `from_id()` methods inherited from the
`ModelMetadataFetchBase` base class.
Civitai has two separate ID spaces: a model ID and a version ID. The
version ID corresponds to a specific model, and is the ID accepted by
`from_id()`. The model ID corresponds to a family of related models,
such as different training checkpoints or 16 vs 32-bit versions. The
`from_civitai_modelid()` method will accept a model ID and return the
metadata from the default version within this model set. The default
version is the same as what the user sees when they click on a model's
thumbnail.
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import re
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
import requests
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from ..metadata_base import (
AnyModelRepoMetadata,
CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
RemoteModelFile,
UnknownMetadataException,
)
from .fetch_base import ModelMetadataFetchBase
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
In the event that the URL points to a model page without the particular version
indicated, the default model version is returned. Otherwise, the requested version
is returned.
"""
if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url), re.IGNORECASE):
model_id = match.group(1)
version_id = match.group(2)
return self.from_civitai_versionid(int(version_id), int(model_id))
elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url), re.IGNORECASE):
model_id = match.group(1)
return self.from_civitai_modelid(int(model_id))
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url), re.IGNORECASE):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given a Civitai model version ID, return a ModelRepoMetadata object.
May raise an `UnknownMetadataException`.
"""
return self.from_civitai_versionid(int(id))
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""
Return metadata from the default version of the indicated model.
May raise an `UnknownMetadataException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json)
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
try:
version_id = version_id or model_json["modelVersions"][0]["id"]
except TypeError as excp:
raise UnknownMetadataException from excp
# loop till we find the section containing the version requested
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
if not version_sections:
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0]
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
primary = [x for x in version_json["files"] if x.get("primary")]
assert len(primary) == 1
primary_file = primary[0]
url = primary_file["downloadUrl"]
if "?" not in url: # work around apparent bug in civitai api
metadata_string = ""
for key, value in primary_file["metadata"].items():
if not value:
continue
metadata_string += f"&{key}={value}"
url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [
RemoteModelFile(
url=url,
path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"],
)
]
return CivitaiMetadata(
id=model_json["id"],
name=version_json["name"],
version_id=version_json["id"],
version_name=version_json["name"],
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
files=model_files,
download_url=version_json["downloadUrl"],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model_json["creator"]["username"],
description=model_json["description"],
version_description=version_json["description"] or "",
tags=model_json["tags"],
trained_words=version_json["trainedWords"],
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
"""
Return a CivitaiMetadata object given a model version id.
May raise an `UnknownMetadataException`.
"""
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json, version_id)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _fix_timezone(date: str) -> str:
return re.sub(r"Z$", "+00:00", date)

View File

@ -1,61 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module is the base class for subclasses that fetch metadata from model repositories
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
from abc import ABC, abstractmethod
from typing import Optional
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
class ModelMetadataFetchBase(ABC):
"""Fetch metadata from remote generative model repositories."""
@abstractmethod
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
pass
@abstractmethod
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a model repository, return a ModelMetadata object.
This method will raise a `UnknownMetadataException`
in the event that the requested model metadata is not found at the provided location.
"""
pass
@abstractmethod
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given an ID for a model, return a ModelMetadata object.
This method will raise a `UnknownMetadataException`
in the event that the requested model's metadata is not found at the provided id.
"""
pass
@classmethod
def from_json(cls, json: str) -> AnyModelRepoMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
return metadata

View File

@ -1,92 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the HuggingFace model repository,
using either a `repo_id` or the model page URL.
Usage:
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
fetcher = HuggingFaceMetadataFetch()
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import re
from pathlib import Path
from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.utils._errors import RepositoryNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from ..metadata_base import (
AnyModelRepoMetadata,
HuggingFaceMetadata,
RemoteModelFile,
UnknownMetadataException,
)
from .fetch_base import ModelMetadataFetchBase
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from HuggingFace."""
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
configure_http_backend(backend_factory=lambda: self._requests)
@classmethod
def from_json(cls, json: str) -> HuggingFaceMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = HuggingFaceMetadata.model_validate_json(json)
return metadata
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
try:
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
except RepositoryNotFoundError as excp:
raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.id,
author=model_info.author,
name=name,
last_modified=model_info.last_modified,
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
tags=model_info.tags,
files=[
RemoteModelFile(
url=hf_hub_url(id, x.rfilename),
path=Path(name, x.rfilename),
size=x.size,
sha256=x.lfs.get("sha256") if x.lfs else None,
)
for x in model_info.siblings
],
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Return a HuggingFaceMetadata object given the model's web page URL.
In the case of an invalid or missing URL, raises a ModelNotFound exception.
"""
if match := re.match(HF_MODEL_RE, str(url), re.IGNORECASE):
repo_id = match.group(1)
return self.from_id(repo_id)
else:
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")

View File

@ -1,202 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""This module defines core text-to-image model metadata fields.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in invokeai.backend.model_manager.config.
Note that the "name" and "description" are also present in `config`
records. This is intentional. The config record fields are intended to
be editable by the user as a form of customization. The metadata
versions of these fields are intended to be kept in sync with the
remote repo.
"""
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
from invokeai.backend.model_manager import ModelRepoVariant
from ..util import select_hf_files
class UnknownMetadataException(Exception):
"""Raised when no metadata is available for a model."""
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: CommercialUsage = Field(
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
)
class RemoteModelFile(BaseModel):
"""Information about a downloadable file that forms part of a model."""
url: AnyHttpUrl = Field(description="The url to download this model file")
path: Path = Field(description="The path to the file, relative to the model root")
size: int = Field(description="The size of this file, in bytes")
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class BaseMetadata(ModelMetadataBase):
"""Adds typing data for discriminated union."""
type: Literal["basemetadata"] = "basemetadata"
class ModelMetadataWithFiles(ModelMetadataBase):
"""Base class for metadata that contains a list of downloadable model file(s)."""
files: List[RemoteModelFile] = Field(description="model files and their sizes", default_factory=list)
def download_urls(
self,
variant: Optional[ModelRepoVariant] = None,
subfolder: Optional[Path] = None,
session: Optional[Session] = None,
) -> List[RemoteModelFile]:
"""
Return a list of URLs needed to download the model.
:param variant: Return files needed to reconstruct the indicated variant (e.g. ModelRepoVariant('fp16'))
:param subfolder: Return files in the designated subfolder only
:param session: A request.Session object for offline testing
Note that the "variant" and "subfolder" concepts currently only apply to HuggingFace.
However Civitai does have fields for the precision and format of its models, and may
provide variant selection criteria in the future.
"""
return self.files
class CivitaiMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai version identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_minmax: Tuple[float, float] = Field(
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
) # note: For future use
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
class HuggingFaceMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
def download_urls(
self,
variant: Optional[ModelRepoVariant] = None,
subfolder: Optional[Path] = None,
session: Optional[Session] = None,
) -> List[RemoteModelFile]:
"""
Return list of downloadable files, filtering by variant and subfolder, if any.
:param variant: Return model files needed to reconstruct the indicated variant
:param subfolder: Return model files from the designated subfolder only
:param session: A request.Session object used for internet-free testing
Note that there is special variant-filtering behavior here:
When the fp16 variant is requested and not available, the
full-precision model is returned.
"""
session = session or Session()
configure_http_backend(backend_factory=lambda: session) # used in testing
paths = select_hf_files.filter_files(
[x.path for x in self.files], variant, subfolder
) # all files in the model
prefix = f"{subfolder}/" if subfolder else ""
# the next step reads model_index.json to determine which subdirectories belong
# to the model
if Path(f"{prefix}model_index.json") in paths:
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
resp = session.get(url)
resp.raise_for_status()
submodels = resp.json()
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, Path(f"{prefix}model_index.json"))
return [x for x in self.files if x.path in paths]
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)

View File

@ -1,221 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -496,9 +496,9 @@ class PipelineFolderProbe(FolderProbeBase):
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf.get("prediction_type", "epsilon") == "v_prediction":
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf.get("prediction_type", "epsilon") == "epsilon":
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")

View File

@ -1,132 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Select the files from a HuggingFace repository needed for a particular model variant.
Usage:
```
from invokeai.backend.model_manager.util.select_hf_files import select_hf_model_files
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
metadata = HuggingFaceMetadataFetch().from_url("https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0")
files_to_download = select_hf_model_files(metadata.files, variant='onnx')
```
"""
import re
from pathlib import Path
from typing import Dict, List, Optional, Set
from ..config import ModelRepoVariant
def filter_files(
files: List[Path],
variant: Optional[ModelRepoVariant] = None,
subfolder: Optional[Path] = None,
) -> List[Path]:
"""
Take a list of files in a HuggingFace repo root and return paths to files needed to load the model.
:param files: List of files relative to the repo root.
:param subfolder: Filter by the indicated subfolder.
:param variant: Filter by files belonging to a particular variant, such as fp16.
The file list can be obtained from the `files` field of HuggingFaceMetadata,
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
"""
variant = variant or ModelRepoVariant.DEFAULT
paths: List[Path] = []
# Start by filtering on model file extensions, discarding images, docs, etc
for file in files:
if file.name.endswith((".json", ".txt")):
paths.append(file)
elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")):
paths.append(file)
# BRITTLENESS WARNING!!
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area of brittleness.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
paths.append(file)
# limit search to subfolder if requested
if subfolder:
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = {}
for path in files:
if path.suffix == ".onnx":
if variant == ModelRepoVariant.ONNX:
result.add(path)
elif "openvino_model" in path.name:
if variant == ModelRepoVariant.OPENVINO:
result.add(path)
elif "flax_model" in path.name:
if variant == ModelRepoVariant.FLAX:
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
ModelRepoVariant.FP16,
ModelRepoVariant.FP32,
ModelRepoVariant.DEFAULT,
]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
variant_label, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
variant_label = ""
suffix = suffixes[0]
basename = parent / path.stem
if previous := basenames.get(basename):
if (
previous.suffix != ".safetensors" and suffix == ".safetensors"
): # replace non-safetensors with safetensors when available
basenames[basename] = path
if variant_label == f".{variant}":
basenames[basename] = path
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
basenames[basename] = path
else:
basenames[basename] = path
else:
continue
for v in basenames.values():
result.add(v)
# If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list
if (
variant
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
and not any(variant.value in x.name for x in result)
):
return set()
# Prune folders that contain just a `config.json`. This happens when
# the requested variant (e.g. "onnx") is missing
directories: Dict[Path, int] = {}
for x in result:
if not x.parent:
continue
directories[x.parent] = directories.get(x.parent, 0) + 1
return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}

View File

@ -34,23 +34,18 @@ def choose_precision(device: torch.device) -> str:
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
if config.precision == "bfloat16":
return "bfloat16"
else:
return "float16"
return "float16"
elif device.type == "mps":
return "float16"
return "float32"
def torch_dtype(device: torch.device) -> torch.dtype:
precision = choose_precision(device)
if precision == "float16":
return torch.float16
if precision == "bfloat16":
return torch.bfloat16
if config.full_precision:
return torch.float32
if choose_precision(device) == "float16":
return torch.bfloat16 if device.type == "cuda" else torch.float16
else:
# "auto", "autocast", "float32"
return torch.float32

View File

@ -93,27 +93,6 @@ module.exports = {
'@typescript-eslint/no-import-type-side-effects': 'error',
'simple-import-sort/imports': 'error',
'simple-import-sort/exports': 'error',
// Prefer @invoke-ai/ui components over chakra
'no-restricted-imports': 'off',
'@typescript-eslint/no-restricted-imports': [
'warn',
{
paths: [
{
name: '@chakra-ui/react',
message: "Please import from '@invoke-ai/ui' instead.",
},
{
name: '@chakra-ui/layout',
message: "Please import from '@invoke-ai/ui' instead.",
},
{
name: '@chakra-ui/portal',
message: "Please import from '@invoke-ai/ui' instead.",
},
],
},
],
},
overrides: [
{

View File

@ -1,7 +1,7 @@
import { PropsWithChildren, memo, useEffect } from 'react';
import { modelChanged } from '../src/features/parameters/store/generationSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui';
import { useGlobalModifiersInit } from '../src/common/hooks/useGlobalModifiers';
/**
* Initializes some state for storybook. Must be in a different component
* so that it is run inside the redux context.

View File

@ -6,6 +6,7 @@ import { Provider } from 'react-redux';
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
import { $baseUrl } from '../src/app/store/nanostores/baseUrl';
import { createStore } from '../src/app/store/store';
import { Container } from '@chakra-ui/react';
// TODO: Disabled for IDE performance issues with our translation JSON
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore

View File

@ -1,51 +1,79 @@
# Invoke UI
# InvokeAI Web UI
<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->
<!-- code_chunk_output -->
- [Invoke UI](#invoke-ui)
- [InvokeAI Web UI](#invokeai-web-ui)
- [Core Libraries](#core-libraries)
- [Redux Toolkit](#redux-toolkit)
- [Socket\.IO](#socketio)
- [Chakra UI](#chakra-ui)
- [KonvaJS](#konvajs)
- [Vite](#vite)
- [i18next & Weblate](#i18next--weblate)
- [openapi-typescript](#openapi-typescript)
- [reactflow](#reactflow)
- [zod](#zod)
- [Client Types Generation](#client-types-generation)
- [Package Scripts](#package-scripts)
- [Client Types Generation](#client-types-generation)
- [Contributing](#contributing)
- [Localization](#localization)
- [Dev Environment](#dev-environment)
- [VSCode Remote Dev](#vscode-remote-dev)
- [VSCode Remote Dev](#vscode-remote-dev)
- [Production builds](#production-builds)
<!-- /code_chunk_output -->
The UI is a fairly straightforward Typescript React app.
## Core Libraries
Invoke's UI is made possible by a number of excellent open-source libraries. The most heavily-used are listed below, but there are many others.
InvokeAI's UI is made possible by a number of excellent open-source libraries. The most heavily-used are listed below, but there are many others.
- [Redux Toolkit]
- [redux-remember]
- [Socket.IO]
- [Chakra UI]
- [KonvaJS]
- [Vite]
- [openapi-typescript]
- [reactflow]
- [zod]
### Redux Toolkit
## Package Scripts
[Redux Toolkit] is used for state management and fetching/caching:
See `package.json` for all scripts.
- `RTK-Query` for data fetching and caching
- `createAsyncThunk` for a couple other HTTP requests
- `createEntityAdapter` to normalize things like images and models
- `createListenerMiddleware` for async workflows
Run with `pnpm <script name>`.
We use [redux-remember] for persistence.
- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Client Types Generation])
- `lint:madge`: check frontend for circular dependencies
- `lint:eslint`: check frontend for code quality
- `lint:prettier`: check frontend for code formatting
- `lint:tsc`: check frontend for type issues
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues
### Socket\.IO
### Client Types Generation
[Socket.IO] is used for server-to-client events, like generation process and queue state changes.
### Chakra UI
[Chakra UI] is our primary UI library, but we also use a few components from [Mantine v6].
### KonvaJS
[KonvaJS] powers the canvas. In the future, we'd like to explore [PixiJS] or WebGPU.
### Vite
[Vite] is our bundler.
### i18next & Weblate
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project. **Only the English source strings should be changed on this repo.**
### openapi-typescript
[openapi-typescript] is used to generate types from the server's OpenAPI schema. See TYPES_CODEGEN.md.
### reactflow
[reactflow] powers the Workflow Editor.
### zod
[zod] schemas are used to model data structures and provide runtime validation.
## Client Types Generation
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
@ -60,18 +88,28 @@ python scripts/invokeai-web.py
pnpm typegen
```
## Package Scripts
See `package.json` for all scripts.
Run with `pnpm <script name>`.
- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Client Types Generation](#client-types-generation))
- `lint:madge`: check frontend for circular dependencies
- `lint:eslint`: check frontend for code quality
- `lint:prettier`: check frontend for code formatting
- `lint:tsc`: check frontend for type issues
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues
## Contributing
Thanks for your interest in contributing to the Invoke Web UI!
Thanks for your interest in contributing to the InvokeAI Web UI!
We encourage you to ping @psychedelicious and @blessedcoolant on [discord] if you want to contribute, just to touch base and ensure your work doesn't conflict with anything else going on. The project is very active.
### Localization
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
**Only the English source strings should be changed on this repo.**
### Dev Environment
Install [node] and [pnpm].
@ -80,19 +118,23 @@ From `invokeai/frontend/web/` run `pnpm i` to get everything set up.
Start everything in dev mode:
1. From `invokeai/frontend/web/`: `pnpm dev`
2. From repo root: `python scripts/invokeai-web.py`
1. Start the dev server: `pnpm dev`
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### VSCode Remote Dev
#### VSCode Remote Dev
We've noticed an intermittent issue with the VSCode Remote Dev port forwarding. If you use this feature of VSCode, you may intermittently click the Invoke button and then get nothing until the request times out.
We've noticed an intermittent issue with the VSCode Remote Dev port forwarding. If you use this feature of VSCode, you may intermittently click the Invoke button and then get nothing until the request times out. Suggest disabling the IDE's port forwarding feature and doing it manually via SSH:
We suggest disabling the IDE's port forwarding feature and doing it manually via SSH:
`ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host`
```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```
### Production builds
For a number of technical and logistical reasons, we need to commit UI build artefacts to the repo.
If you submit a PR, there is a good chance we will ask you to include a separate commit with a build of the app.
To build for production, run `pnpm build`.
[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
@ -101,11 +143,12 @@ ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
[redux-remember]: https://github.com/zewish/redux-remember
[Socket.IO]: https://github.com/socketio/socket.io
[Chakra UI]: https://github.com/chakra-ui/chakra-ui
[Mantine v6]: https://v6.mantine.dev/
[KonvaJS]: https://github.com/konvajs/react-konva
[PixiJS]: https://github.com/pixijs/pixijs
[Vite]: https://github.com/vitejs/vite
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[reactflow]: https://github.com/xyflow/xyflow
[zod]: https://github.com/colinhacks/zod
[Client Types Generation]: #client-types-generation

View File

@ -23,7 +23,7 @@
- [Primitive Types](#primitive-types)
- [Complex Types](#complex-types)
- [Collection Types](#collection-types)
- [Collection or Scalar Types](#collection-or-scalar-types)
- [Polymorphic Types](#polymorphic-types)
- [Optional Fields](#optional-fields)
- [Building Field Input Templates](#building-field-input-templates)
- [Building Field Output Templates](#building-field-output-templates)

View File

@ -32,8 +32,8 @@
"fix": "eslint --fix . && prettier --log-level warn --write .",
"preinstall": "npx only-allow pnpm",
"postinstall": "pnpm run theme",
"theme": "chakra-cli tokens node_modules/@invoke-ai/ui-library",
"theme:watch": "chakra-cli tokens node_modules/@invoke-ai/ui-library --watch",
"theme": "chakra-cli tokens src/theme/theme.ts",
"theme:watch": "chakra-cli tokens src/theme/theme.ts --watch",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",
"unimported": "npx unimported"
@ -52,12 +52,20 @@
}
},
"dependencies": {
"@chakra-ui/anatomy": "^2.2.2",
"@chakra-ui/icons": "^2.1.1",
"@chakra-ui/layout": "^2.3.1",
"@chakra-ui/portal": "^2.1.0",
"@chakra-ui/react": "^2.8.2",
"@chakra-ui/react-use-size": "^2.1.0",
"@chakra-ui/styled-system": "^2.9.2",
"@chakra-ui/theme-tools": "^2.1.2",
"@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/utilities": "^3.2.2",
"@emotion/react": "^11.11.3",
"@emotion/styled": "^11.11.0",
"@fontsource-variable/inter": "^5.0.16",
"@invoke-ai/ui-library": "0.0.18-1a2150a.0",
"@mantine/form": "6.0.21",
"@nanostores/react": "^0.7.1",
"@reduxjs/toolkit": "2.0.1",
@ -65,12 +73,12 @@
"chakra-react-select": "^4.7.6",
"compare-versions": "^6.1.0",
"dateformat": "^5.0.3",
"framer-motion": "^10.18.0",
"framer-motion": "^10.17.9",
"i18next": "^23.7.16",
"i18next-http-backend": "^2.4.2",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.1",
"konva": "^9.3.0",
"lodash-es": "^4.17.21",
"nanostores": "^0.9.5",
"new-github-issue-url": "^1.0.0",
@ -82,29 +90,29 @@
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
"react-error-boundary": "^4.0.12",
"react-hook-form": "^7.49.3",
"react-hotkeys-hook": "4.4.4",
"react-hook-form": "^7.49.2",
"react-hotkeys-hook": "4.4.3",
"react-i18next": "^14.0.0",
"react-icons": "^5.0.1",
"react-icons": "^4.12.0",
"react-konva": "^18.2.10",
"react-redux": "9.1.0",
"react-resizable-panels": "^1.0.9",
"react-redux": "9.0.4",
"react-resizable-panels": "^1.0.8",
"react-select": "5.8.0",
"react-textarea-autosize": "^8.5.3",
"react-use": "^17.4.3",
"react-use": "^17.4.2",
"react-virtuoso": "^4.6.2",
"reactflow": "^11.10.2",
"reactflow": "^11.10.1",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"roarr": "^7.21.0",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.4",
"socket.io-client": "^4.7.3",
"type-fest": "^4.9.0",
"use-debounce": "^10.0.0",
"use-image": "^1.1.1",
"uuid": "^9.0.1",
"zod": "^3.22.4",
"zod-validation-error": "^3.0.0"
"zod-validation-error": "^2.1.0"
},
"peerDependencies": {
"@chakra-ui/cli": "^2.4.1",
@ -116,32 +124,32 @@
"devDependencies": {
"@arthurgeron/eslint-plugin-react-usememo": "^2.2.3",
"@chakra-ui/cli": "^2.4.1",
"@storybook/addon-docs": "^7.6.10",
"@storybook/addon-essentials": "^7.6.10",
"@storybook/addon-interactions": "^7.6.10",
"@storybook/addon-links": "^7.6.10",
"@storybook/addon-storysource": "^7.6.10",
"@storybook/blocks": "^7.6.10",
"@storybook/manager-api": "^7.6.10",
"@storybook/react": "^7.6.10",
"@storybook/react-vite": "^7.6.10",
"@storybook/test": "^7.6.10",
"@storybook/theming": "^7.6.10",
"@storybook/addon-docs": "^7.6.7",
"@storybook/addon-essentials": "^7.6.7",
"@storybook/addon-interactions": "^7.6.7",
"@storybook/addon-links": "^7.6.7",
"@storybook/addon-storysource": "^7.6.7",
"@storybook/blocks": "^7.6.7",
"@storybook/manager-api": "^7.6.7",
"@storybook/react": "^7.6.7",
"@storybook/react-vite": "^7.6.7",
"@storybook/test": "^7.6.7",
"@storybook/theming": "^7.6.7",
"@types/dateformat": "^5.0.2",
"@types/lodash-es": "^4.17.12",
"@types/node": "^20.11.5",
"@types/react": "^18.2.48",
"@types/node": "^20.10.7",
"@types/react": "^18.2.47",
"@types/react-dom": "^18.2.18",
"@types/uuid": "^9.0.7",
"@typescript-eslint/eslint-plugin": "^6.19.0",
"@typescript-eslint/parser": "^6.19.0",
"@typescript-eslint/eslint-plugin": "^6.18.0",
"@typescript-eslint/parser": "^6.18.0",
"@vitejs/plugin-react-swc": "^3.5.0",
"concurrently": "^8.2.2",
"eslint": "^8.56.0",
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-i18next": "^6.0.3",
"eslint-plugin-import": "^2.29.1",
"eslint-plugin-path": "^1.2.4",
"eslint-plugin-path": "^1.2.3",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^4.6.0",
"eslint-plugin-simple-import-sort": "^10.0.0",
@ -150,16 +158,16 @@
"madge": "^6.1.0",
"openapi-types": "^12.1.3",
"openapi-typescript": "^6.7.3",
"prettier": "^3.2.4",
"prettier": "^3.1.1",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^7.6.10",
"storybook": "^7.6.7",
"ts-toolbelt": "^9.6.0",
"typescript": "^5.3.3",
"vite": "^5.0.11",
"vite-plugin-css-injected-by-js": "^3.3.1",
"vite-plugin-dts": "^3.7.1",
"vite-plugin-dts": "^3.7.0",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.3.1"
"vite-tsconfig-paths": "^4.2.3"
},
"pnpm": {
"patchedDependencies": {

File diff suppressed because it is too large Load Diff

View File

@ -110,28 +110,7 @@
"somethingWentWrong": "Etwas ist schief gelaufen",
"copyError": "$t(gallery.copy) Fehler",
"input": "Eingabe",
"notInstalled": "Nicht $t(common.installed)",
"advancedOptions": "Erweiterte Einstellungen",
"alpha": "Alpha",
"red": "Rot",
"green": "Grün",
"blue": "Blau",
"delete": "Löschen",
"or": "oder",
"direction": "Richtung",
"free": "Frei",
"save": "Speichern",
"preferencesLabel": "Präferenzen",
"created": "Erstellt",
"prevPage": "Vorherige Seite",
"nextPage": "Nächste Seite",
"unknownError": "Unbekannter Fehler",
"unsaved": "Nicht gespeichert",
"aboutDesc": "Verwenden Sie Invoke für die Arbeit? Dann siehe hier:",
"localSystem": "Lokales System",
"orderBy": "Ordnen nach",
"saveAs": "Speicher als",
"updated": "Aktualisiert"
"notInstalled": "Nicht $t(common.installed)"
},
"gallery": {
"generations": "Erzeugungen",
@ -722,8 +701,7 @@
"invokeProgressBar": "Invoke Fortschrittsanzeige",
"mode": "Modus",
"resetUI": "$t(accessibility.reset) von UI",
"createIssue": "Ticket erstellen",
"about": "Über"
"createIssue": "Ticket erstellen"
},
"boards": {
"autoAddBoard": "Automatisches Hinzufügen zum Ordner",
@ -831,8 +809,7 @@
"canny": "Canny",
"hedDescription": "Ganzheitlich verschachtelte Kantenerkennung",
"scribble": "Scribble",
"maxFaces": "Maximal Anzahl Gesichter",
"resizeSimple": "Größe ändern (einfach)"
"maxFaces": "Maximal Anzahl Gesichter"
},
"queue": {
"status": "Status",
@ -1022,27 +999,5 @@
"selectLoRA": "Wählen ein LoRA aus",
"esrganModel": "ESRGAN Modell",
"addLora": "LoRA hinzufügen"
},
"accordions": {
"generation": {
"title": "Erstellung",
"modelTab": "Modell",
"conceptsTab": "Konzepte"
},
"image": {
"title": "Bild"
},
"advanced": {
"title": "Erweitert"
},
"control": {
"title": "Kontrolle",
"controlAdaptersTab": "Kontroll Adapter",
"ipTab": "Bild Beschreibung"
},
"compositing": {
"coherenceTab": "Kohärenzpass",
"infillTab": "Füllung"
}
}
}

View File

@ -1,6 +1,5 @@
{
"accessibility": {
"about": "About",
"copyMetadataJson": "Copy metadata JSON",
"createIssue": "Create Issue",
"exitViewer": "Exit Viewer",
@ -75,8 +74,6 @@
}
},
"common": {
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"advanced": "Advanced",
"advancedOptions": "Advanced Options",
@ -139,7 +136,6 @@
"load": "Load",
"loading": "Loading",
"loadingInvokeAI": "Loading Invoke AI",
"localSystem": "Local System",
"learnMore": "Learn More",
"modelManager": "Model Manager",
"nodeEditor": "Node Editor",
@ -203,11 +199,7 @@
"prevPage": "Previous Page",
"nextPage": "Next Page",
"unknownError": "Unknown Error",
"unsaved": "Unsaved",
"red": "Red",
"green": "Green",
"blue": "Blue",
"alpha": "Alpha"
"unsaved": "Unsaved"
},
"controlnet": {
"controlAdapter_one": "Control Adapter",
@ -224,7 +216,6 @@
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
"base": "Base",
"beginEndStepPercent": "Begin / End Step Percentage",
"bgth": "bg_th",
"canny": "Canny",
@ -238,8 +229,6 @@
"controlMode": "Control Mode",
"crop": "Crop",
"delete": "Delete",
"depthAnything": "Depth Anything",
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
"depthMidas": "Depth (Midas)",
"depthMidasDescription": "Depth map generation using Midas",
"depthZoe": "Depth (Zoe)",
@ -259,7 +248,6 @@
"colorMapTileSize": "Tile Size",
"importImageFromCanvas": "Import Image From Canvas",
"importMaskFromCanvas": "Import Mask From Canvas",
"large": "Large",
"lineart": "Lineart",
"lineartAnime": "Lineart Anime",
"lineartAnimeDescription": "Anime-style lineart processing",
@ -272,7 +260,6 @@
"minConfidence": "Min Confidence",
"mlsd": "M-LSD",
"mlsdDescription": "Minimalist Line Segment Detector",
"modelSize": "Model Size",
"none": "None",
"noneDescription": "No processing applied",
"normalBae": "Normal BAE",
@ -293,7 +280,6 @@
"selectModel": "Select a model",
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced",
"small": "Small",
"toggleControlNet": "Toggle this ControlNet",
"w": "W",
"weight": "Weight",
@ -606,10 +592,6 @@
"desc": "Send current image to Image to Image",
"title": "Send To Image To Image"
},
"remixImage": {
"desc": "Use all parameters except seed from the current image",
"title": "Remix image"
},
"setParameters": {
"desc": "Use all parameters of the current image",
"title": "Set Parameters"
@ -717,7 +699,6 @@
"clearCheckpointFolder": "Clear Checkpoint Folder",
"closeAdvanced": "Close Advanced",
"config": "Config",
"configFile": "Config File",
"configValidationMsg": "Path to the config file of your model.",
"conversionNotSupported": "Conversion Not Supported",
"convert": "Convert",
@ -1108,7 +1089,6 @@
"boundingBoxHeader": "Bounding Box",
"boundingBoxHeight": "Bounding Box Height",
"boundingBoxWidth": "Bounding Box Width",
"boxBlur": "Box Blur",
"cancel": {
"cancel": "Cancel",
"immediate": "Cancel immediately",
@ -1136,7 +1116,6 @@
"enableNoiseSettings": "Enable Noise Settings",
"faceRestoration": "Face Restoration",
"general": "General",
"gaussianBlur": "Gaussian Blur",
"height": "Height",
"hidePreview": "Hide Preview",
"hiresOptim": "High Res Optimization",
@ -1226,7 +1205,6 @@
"useCpuNoise": "Use CPU Noise",
"cpuNoise": "CPU Noise",
"gpuNoise": "GPU Noise",
"remixImage": "Remix Image",
"useInitImg": "Use Initial Image",
"usePrompt": "Use Prompt",
"useSeed": "Use Seed",
@ -1385,7 +1363,6 @@
"promptNotSet": "Prompt Not Set",
"promptNotSetDesc": "Could not find prompt for this image.",
"promptSet": "Prompt Set",
"resetInitialImage": "Reset Initial Image",
"seedNotSet": "Seed Not Set",
"seedNotSetDesc": "Could not find seed for this image.",
"seedSet": "Seed Set",
@ -1402,7 +1379,6 @@
"uploadFailed": "Upload failed",
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"uploadFailedUnableToLoadDesc": "Unable to load file",
"uploadInitialImage": "Upload Initial Image",
"upscalingFailed": "Upscaling Failed",
"workflowLoaded": "Workflow Loaded",
"problemRetrievingWorkflow": "Problem Retrieving Workflow",
@ -1708,7 +1684,6 @@
"workflowLibrary": "Library",
"userWorkflows": "My Workflows",
"defaultWorkflows": "Default Workflows",
"projectWorkflows": "Project Workflows",
"openWorkflow": "Open Workflow",
"uploadWorkflow": "Load from File",
"deleteWorkflow": "Delete Workflow",
@ -1721,7 +1696,6 @@
"workflowSaved": "Workflow Saved",
"noRecentWorkflows": "No Recent Workflows",
"noUserWorkflows": "No User Workflows",
"noWorkflows": "No Workflows",
"noSystemWorkflows": "No System Workflows",
"problemLoading": "Problem Loading Workflows",
"loading": "Loading Workflows",

View File

@ -118,14 +118,7 @@
"advancedOptions": "Opzioni avanzate",
"free": "Libero",
"or": "o",
"preferencesLabel": "Preferenze",
"red": "Rosso",
"aboutHeading": "Possiedi il tuo potere creativo",
"aboutDesc": "Utilizzi Invoke per lavoro? Guarda qui:",
"localSystem": "Sistema locale",
"green": "Verde",
"blue": "Blu",
"alpha": "Alfa"
"preferencesLabel": "Preferenze"
},
"gallery": {
"generations": "Generazioni",
@ -528,8 +521,7 @@
"customConfigFileLocation": "Posizione del file di configurazione personalizzato",
"vaePrecision": "Precisione VAE",
"noModelSelected": "Nessun modello selezionato",
"conversionNotSupported": "Conversione non supportata",
"configFile": "File di configurazione"
"conversionNotSupported": "Conversione non supportata"
},
"parameters": {
"images": "Immagini",
@ -554,7 +546,7 @@
"upscaleImage": "Amplia Immagine",
"scale": "Scala",
"otherOptions": "Altre opzioni",
"seamlessTiling": "Piastrella senza giunte",
"seamlessTiling": "Piastrella senza cuciture",
"hiresOptim": "Ottimizzazione alta risoluzione",
"imageFit": "Adatta l'immagine iniziale alle dimensioni di output",
"codeformerFidelity": "Fedeltà",
@ -600,8 +592,8 @@
"hidePreview": "Nascondi l'anteprima",
"showPreview": "Mostra l'anteprima",
"noiseSettings": "Rumore",
"seamlessXAxis": "Piastrella senza giunte Asse X",
"seamlessYAxis": "Piastrella senza giunte Asse Y",
"seamlessXAxis": "Piastrella senza cucitura Asse X",
"seamlessYAxis": "Piastrella senza cucitura Asse Y",
"scheduler": "Campionatore",
"boundingBoxWidth": "Larghezza riquadro di delimitazione",
"boundingBoxHeight": "Altezza riquadro di delimitazione",
@ -668,9 +660,7 @@
"lockAspectRatio": "Blocca proporzioni",
"swapDimensions": "Scambia dimensioni",
"aspect": "Aspetto",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
"boxBlur": "Box",
"gaussianBlur": "Gaussian"
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)"
},
"settings": {
"models": "Modelli",
@ -804,15 +794,13 @@
"invalidUpload": "Caricamento non valido",
"problemDeletingWorkflow": "Problema durante l'eliminazione del flusso di lavoro",
"workflowDeleted": "Flusso di lavoro eliminato",
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
"resetInitialImage": "Reimposta l'immagine iniziale",
"uploadInitialImage": "Carica l'immagine iniziale"
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro"
},
"tooltip": {
"feature": {
"prompt": "Questo è il campo del prompt. Il prompt include oggetti di generazione e termini stilistici. Puoi anche aggiungere il peso (importanza del token) nel prompt, ma i comandi e i parametri dell'interfaccia a linea di comando non funzioneranno.",
"gallery": "Galleria visualizza le generazioni dalla cartella degli output man mano che vengono create. Le impostazioni sono memorizzate all'interno di file e accessibili dal menu contestuale.",
"other": "Queste opzioni abiliteranno modalità di elaborazione alternative per Invoke. 'Piastrella senza giunte' creerà immagini piastrellabili senza giunture. 'Ottimizzazione Alta risoluzione' è la generazione in due passaggi con 'Immagine a Immagine': usa questa impostazione quando vuoi un'immagine più grande e più coerente senza artefatti. Ci vorrà più tempo del solito 'Testo a Immagine'.",
"other": "Queste opzioni abiliteranno modalità di elaborazione alternative per Invoke. 'Piastrella senza cuciture' creerà modelli ripetuti nell'output. 'Ottimizzazione Alta risoluzione' è la generazione in due passaggi con 'Immagine a Immagine': usa questa impostazione quando vuoi un'immagine più grande e più coerente senza artefatti. Ci vorrà più tempo del solito 'Testo a Immagine'.",
"seed": "Il valore del Seme influenza il rumore iniziale da cui è formata l'immagine. Puoi usare i semi già esistenti dalle immagini precedenti. 'Soglia del rumore' viene utilizzato per mitigare gli artefatti a valori CFG elevati (provare l'intervallo 0-10) e Perlin per aggiungere il rumore Perlin durante la generazione: entrambi servono per aggiungere variazioni ai risultati.",
"variations": "Prova una variazione con un valore compreso tra 0.1 e 1.0 per modificare il risultato per un dato seme. Variazioni interessanti del seme sono comprese tra 0.1 e 0.3.",
"upscale": "Utilizza ESRGAN per ingrandire l'immagine subito dopo la generazione.",
@ -911,8 +899,7 @@
"loadMore": "Carica altro",
"mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
"createIssue": "Segnala un problema",
"about": "Informazioni"
"createIssue": "Segnala un problema"
},
"ui": {
"hideProgressImages": "Nascondi avanzamento immagini",
@ -1184,7 +1171,7 @@
"depthMidas": "Profondità (Midas)",
"enableControlnet": "Abilita ControlNet",
"detectResolution": "Rileva risoluzione",
"controlMode": "Controllo",
"controlMode": "Modalità Controllo",
"cannyDescription": "Canny rilevamento bordi",
"depthZoe": "Profondità (Zoe)",
"autoConfigure": "Configura automaticamente il processore",
@ -1200,13 +1187,13 @@
"ipAdapterModel": "Modello Adattatore",
"resetControlImage": "Reimposta immagine di controllo",
"f": "F",
"h": "A",
"h": "H",
"prompt": "Prompt",
"openPoseDescription": "Stima della posa umana utilizzando Openpose",
"resizeMode": "Ridimensionamento",
"resizeMode": "Modalità ridimensionamento",
"weight": "Peso",
"selectModel": "Seleziona un modello",
"w": "L",
"w": "W",
"processor": "Processore",
"none": "Nessuno",
"pidiDescription": "Elaborazione immagini PIDI",
@ -1228,7 +1215,7 @@
"hedDescription": "Rilevamento dei bordi nidificati olisticamente",
"setControlImageDimensions": "Imposta le dimensioni dell'immagine di controllo su L/A",
"resetIPAdapterImage": "Reimposta immagine Adattatore IP",
"handAndFace": "Mani e volti",
"handAndFace": "Mano e faccia",
"enableIPAdapter": "Abilita Adattatore IP",
"maxFaces": "Numero massimo di volti",
"addT2IAdapter": "Aggiungi $t(common.t2iAdapter)",
@ -1242,7 +1229,7 @@
"controlAdapter_other": "Adattatori di Controllo",
"megaControl": "Mega ControlNet",
"minConfidence": "Confidenza minima",
"scribble": "Scarabocchio",
"scribble": "Scribble",
"amult": "Angolo di illuminazione",
"coarse": "Approssimativo",
"resizeSimple": "Ridimensiona (semplice)"

View File

@ -1,4 +1,4 @@
import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { Box } from '@chakra-ui/react';
import { useSocketIO } from 'app/hooks/useSocketIO';
import { useLogger } from 'app/logging/useLogger';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@ -8,6 +8,7 @@ import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { useGlobalModifiersInit } from 'common/hooks/useGlobalModifiers';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';

View File

@ -1,12 +1,11 @@
import { Button, Flex, Heading, Link, Text, useToast } from '@invoke-ai/ui-library';
import { Flex, Heading, Link, useToast } from '@chakra-ui/react';
import { InvButton } from 'common/components/InvButton/InvButton';
import { InvText } from 'common/components/InvText/wrapper';
import newGithubIssueUrl from 'new-github-issue-url';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
PiArrowCounterClockwiseBold,
PiArrowSquareOutBold,
PiCopyBold,
} from 'react-icons/pi';
import { FaCopy, FaExternalLinkAlt } from 'react-icons/fa';
import { FaArrowRotateLeft } from 'react-icons/fa6';
import { serializeError } from 'serialize-error';
type Props = {
@ -63,24 +62,24 @@ const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
justifyContent="space-between"
alignItems="center"
>
<Text fontWeight="semibold" color="error.400">
<InvText fontWeight="semibold" color="error.400">
{error.name}: {error.message}
</Text>
</InvText>
</Flex>
<Flex gap={4}>
<Button
leftIcon={<PiArrowCounterClockwiseBold />}
<InvButton
leftIcon={<FaArrowRotateLeft />}
onClick={resetErrorBoundary}
>
{t('accessibility.resetUI')}
</Button>
<Button leftIcon={<PiCopyBold />} onClick={handleCopy}>
</InvButton>
<InvButton leftIcon={<FaCopy />} onClick={handleCopy}>
{t('common.copyError')}
</Button>
</InvButton>
<Link href={url} isExternal>
<Button leftIcon={<PiArrowSquareOutBold />}>
<InvButton leftIcon={<FaExternalLinkAlt />}>
{t('accessibility.createIssue')}
</Button>
</InvButton>
</Link>
</Flex>
</Flex>

View File

@ -1,16 +1,12 @@
import '@fontsource-variable/inter';
import 'overlayscrollbars/overlayscrollbars.css';
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
import {
ChakraProvider,
DarkMode,
extendTheme,
theme as _theme,
TOAST_OPTIONS,
} from '@invoke-ai/ui-library';
import { ChakraProvider, extendTheme } from '@chakra-ui/react';
import type { ReactNode } from 'react';
import { memo, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme, TOAST_OPTIONS } from 'theme/theme';
type ThemeLocaleProviderProps = {
children: ReactNode;
@ -23,7 +19,7 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
const theme = useMemo(() => {
return extendTheme({
..._theme,
...invokeAITheme,
direction,
});
}, [direction]);
@ -34,7 +30,7 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
return (
<ChakraProvider theme={theme} toastOptions={TOAST_OPTIONS}>
<DarkMode>{children}</DarkMode>
{children}
</ChakraProvider>
);
}

View File

@ -1,4 +1,4 @@
import { useToast } from '@invoke-ai/ui-library';
import { useToast } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
import type { MakeToastArg } from 'features/system/util/makeToast';

View File

@ -45,7 +45,7 @@ export const useSocketIO = () => {
const socketOptions = useMemo(() => {
const options: Partial<ManagerOptions & SocketOptions> = {
timeout: 60000,
path: baseUrl ? '/ws/socket.io' : `${window.location.pathname}ws/socket.io`,
path: '/ws/socket.io',
autoConnect: false, // achtung! removing this breaks the dynamic middleware
forceNew: true,
};
@ -56,7 +56,7 @@ export const useSocketIO = () => {
}
return { ...options, ...addlSocketOptions };
}, [authToken, addlSocketOptions, baseUrl]);
}, [authToken, addlSocketOptions]);
useEffect(() => {
if ($isSocketInitialized.get()) {

View File

@ -1,10 +1,11 @@
import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import { createStandaloneToast } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { queueApi } from 'services/api/endpoints/queue';
import { theme, TOAST_OPTIONS } from 'theme/theme';
import { startAppListening } from '..';

View File

@ -1,4 +1,4 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
import type { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {

View File

@ -36,9 +36,9 @@ export const addModelSelectedListener = () => {
const newModel = result.data;
const newBaseModel = newModel.base_model;
const { base_model } = newModel;
const didBaseModelChange =
state.generation.model?.base_model !== newBaseModel;
state.generation.model?.base_model !== base_model;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
@ -46,7 +46,7 @@ export const addModelSelectedListener = () => {
// handle incompatible loras
forEach(state.lora.loras, (lora, id) => {
if (lora.base_model !== newBaseModel) {
if (lora.base_model !== base_model) {
dispatch(loraRemoved(id));
modelsCleared += 1;
}
@ -54,14 +54,14 @@ export const addModelSelectedListener = () => {
// handle incompatible vae
const { vae } = state.generation;
if (vae && vae.base_model !== newBaseModel) {
if (vae && vae.base_model !== base_model) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// handle incompatible controlnets
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base_model !== newBaseModel) {
if (ca.model?.base_model !== base_model) {
dispatch(
controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false })
);

View File

@ -46,14 +46,14 @@ export const addDynamicPromptsListener = () => {
if (cachedPrompts) {
dispatch(promptsChanged(cachedPrompts.prompts));
dispatch(parsingErrorChanged(cachedPrompts.error));
return;
}
if (!getShouldProcessPrompt(state.generation.positivePrompt)) {
if (state.dynamicPrompts.isLoading) {
dispatch(isLoadingChanged(false));
}
dispatch(promptsChanged([state.generation.positivePrompt]));
dispatch(parsingErrorChanged(undefined));
dispatch(isErrorChanged(false));
return;
}
@ -78,6 +78,7 @@ export const addDynamicPromptsListener = () => {
dispatch(promptsChanged(res.prompts));
dispatch(parsingErrorChanged(res.error));
dispatch(isErrorChanged(false));
dispatch(isLoadingChanged(false));
} catch {
dispatch(isErrorChanged(true));
dispatch(isLoadingChanged(false));

View File

@ -1,4 +1,4 @@
import type { MenuItemProps } from '@invoke-ai/ui-library';
import type { MenuItemProps } from '@chakra-ui/react';
import { atom } from 'nanostores';
export type CustomStarUi = {

View File

@ -1,10 +1,5 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import {
CompositeNumberInput,
Flex,
FormControl,
FormLabel,
} from '@invoke-ai/ui-library';
import type { ChakraProps } from '@chakra-ui/react';
import { Flex } from '@chakra-ui/react';
import type { CSSProperties } from 'react';
import { memo, useCallback } from 'react';
import { RgbaColorPicker } from 'react-colorful';
@ -12,7 +7,9 @@ import type {
ColorPickerBaseProps,
RgbaColor,
} from 'react-colorful/dist/types';
import { useTranslation } from 'react-i18next';
import { InvControl } from './InvControl/InvControl';
import { InvNumberInput } from './InvNumberInput/InvNumberInput';
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor> & {
withNumberInput?: boolean;
@ -38,7 +35,6 @@ const numberInputWidth: ChakraProps['w'] = '4.2rem';
const IAIColorPicker = (props: IAIColorPickerProps) => {
const { color, onChange, withNumberInput, ...rest } = props;
const { t } = useTranslation();
const handleChangeR = useCallback(
(r: number) => onChange({ ...color, r }),
[color, onChange]
@ -65,9 +61,8 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
/>
{withNumberInput && (
<Flex>
<FormControl>
<FormLabel>{t('common.red')}</FormLabel>
<CompositeNumberInput
<InvControl label="Red">
<InvNumberInput
value={color.r}
onChange={handleChangeR}
min={0}
@ -76,10 +71,9 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
w={numberInputWidth}
defaultValue={90}
/>
</FormControl>
<FormControl>
<FormLabel>{t('common.green')}</FormLabel>
<CompositeNumberInput
</InvControl>
<InvControl label="Green">
<InvNumberInput
value={color.g}
onChange={handleChangeG}
min={0}
@ -88,10 +82,9 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
w={numberInputWidth}
defaultValue={90}
/>
</FormControl>
<FormControl>
<FormLabel>{t('common.blue')}</FormLabel>
<CompositeNumberInput
</InvControl>
<InvControl label="Blue">
<InvNumberInput
value={color.b}
onChange={handleChangeB}
min={0}
@ -100,10 +93,9 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
w={numberInputWidth}
defaultValue={255}
/>
</FormControl>
<FormControl>
<FormLabel>{t('common.alpha')}</FormLabel>
<CompositeNumberInput
</InvControl>
<InvControl label="Alpha">
<InvNumberInput
value={color.a}
onChange={handleChangeA}
step={0.1}
@ -112,7 +104,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
w={numberInputWidth}
defaultValue={1}
/>
</FormControl>
</InvControl>
</Flex>
)}
</Flex>

View File

@ -1,5 +1,9 @@
import type { ChakraProps, FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import type {
ChakraProps,
FlexProps,
SystemStyleObject,
} from '@chakra-ui/react';
import { Flex, Icon, Image } from '@chakra-ui/react';
import {
IAILoadingImageFallback,
IAINoContentFallback,
@ -18,16 +22,16 @@ import type {
SyntheticEvent,
} from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
import { FaImage, FaUpload } from 'react-icons/fa';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
import SelectionOverlay from './SelectionOverlay';
const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
const defaultUploadElement = <Icon as={FaUpload} boxSize={16} />;
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
const defaultNoContentFallback = <IAINoContentFallback icon={FaImage} />;
type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined;

View File

@ -1,8 +1,9 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@chakra-ui/react';
import type { MouseEvent, ReactElement } from 'react';
import { memo, useMemo } from 'react';
import { InvIconButton } from './InvIconButton/InvIconButton';
type Props = {
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
tooltip: string;
@ -25,7 +26,7 @@ const IAIDndImageIcon = (props: Props) => {
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
filter: 'drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))',
},
...styleOverrides,
}),
@ -33,7 +34,7 @@ const IAIDndImageIcon = (props: Props) => {
);
return (
<IconButton
<InvIconButton
onClick={onClick}
aria-label={tooltip}
tooltip={tooltip}

View File

@ -1,5 +1,5 @@
import type { BoxProps } from '@invoke-ai/ui-library';
import { Box } from '@invoke-ai/ui-library';
import type { BoxProps } from '@chakra-ui/react';
import { Box } from '@chakra-ui/react';
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDraggableData } from 'features/dnd/types';
import { memo, useRef } from 'react';

View File

@ -1,10 +1,10 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { Box, Flex } from '@chakra-ui/react';
import type { AnimationProps } from 'framer-motion';
import { motion } from 'framer-motion';
import type { ReactNode } from 'react';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { v4 as uuidv4 } from 'uuid';
type Props = {
isOver: boolean;
label?: ReactNode;
@ -23,8 +23,7 @@ const exit: AnimationProps['exit'] = {
};
const IAIDropOverlay = (props: Props) => {
const { t } = useTranslation();
const { isOver, label = t('gallery.drop') } = props;
const { isOver, label = 'Drop' } = props;
const motionId = useRef(uuidv4());
return (
<motion.div

View File

@ -1,4 +1,4 @@
import { Box } from '@invoke-ai/ui-library';
import { Box } from '@chakra-ui/react';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';

View File

@ -1,5 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Skeleton } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@chakra-ui/react';
import { Box, Skeleton } from '@chakra-ui/react';
import { memo } from 'react';
const skeletonStyles: SystemStyleObject = {

View File

@ -1,9 +1,11 @@
import type { As, ChakraProps, FlexProps } from '@invoke-ai/ui-library';
import { Flex, Icon, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import type { As, FlexProps, StyleProps } from '@chakra-ui/react';
import { Flex, Icon, Skeleton, Spinner } from '@chakra-ui/react';
import { memo, useMemo } from 'react';
import { PiImageBold } from 'react-icons/pi';
import { FaImage } from 'react-icons/fa';
import type { ImageDTO } from 'services/api/types';
import { InvText } from './InvText/wrapper';
type Props = { image: ImageDTO | undefined };
export const IAILoadingImageFallback = memo((props: Props) => {
@ -37,11 +39,11 @@ IAILoadingImageFallback.displayName = 'IAILoadingImageFallback';
type IAINoImageFallbackProps = FlexProps & {
label?: string;
icon?: As | null;
boxSize?: ChakraProps['boxSize'];
boxSize?: StyleProps['boxSize'];
};
export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
const { icon = PiImageBold, boxSize = 16, sx, ...rest } = props;
const { icon = FaImage, boxSize = 16, sx, ...rest } = props;
const styles = useMemo(
() => ({
@ -64,9 +66,9 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
<Flex sx={styles} {...rest}>
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && (
<Text textAlign="center" fontSize="md">
<InvText textAlign="center" fontSize="md">
{props.label}
</Text>
</InvText>
)}
</Flex>
);
@ -100,7 +102,7 @@ export const IAINoContentFallbackWithSpinner = memo(
return (
<Flex sx={styles} {...rest}>
<Spinner size="xl" />
{props.label && <Text textAlign="center">{props.label}</Text>}
{props.label && <InvText textAlign="center">{props.label}</InvText>}
</Flex>
);
}

View File

@ -0,0 +1,147 @@
import { Divider, Flex, Image, Portal } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { InvButton } from 'common/components/InvButton/InvButton';
import { InvHeading } from 'common/components/InvHeading/wrapper';
import {
InvPopover,
InvPopoverBody,
InvPopoverCloseButton,
InvPopoverContent,
InvPopoverTrigger,
} from 'common/components/InvPopover/wrapper';
import { InvText } from 'common/components/InvText/wrapper';
import { merge, omit } from 'lodash-es';
import type { ReactElement } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaExternalLinkAlt } from 'react-icons/fa';
import type { Feature, PopoverData } from './constants';
import { OPEN_DELAY, POPOVER_DATA, POPPER_MODIFIERS } from './constants';
type Props = {
feature: Feature;
inPortal?: boolean;
children: ReactElement;
};
const IAIInformationalPopover = ({
feature,
children,
inPortal = true,
...rest
}: Props) => {
const shouldEnableInformationalPopovers = useAppSelector(
(s) => s.system.shouldEnableInformationalPopovers
);
const data = useMemo(() => POPOVER_DATA[feature], [feature]);
const popoverProps = useMemo(
() => merge(omit(data, ['image', 'href', 'buttonLabel']), rest),
[data, rest]
);
if (!shouldEnableInformationalPopovers) {
return children;
}
return (
<InvPopover
isLazy
closeOnBlur={false}
trigger="hover"
variant="informational"
openDelay={OPEN_DELAY}
modifiers={POPPER_MODIFIERS}
placement="top"
{...popoverProps}
>
<InvPopoverTrigger>{children}</InvPopoverTrigger>
{inPortal ? (
<Portal>
<PopoverContent data={data} feature={feature} />
</Portal>
) : (
<PopoverContent data={data} feature={feature} />
)}
</InvPopover>
);
};
export default memo(IAIInformationalPopover);
type PopoverContentProps = {
data?: PopoverData;
feature: Feature;
};
const PopoverContent = ({ data, feature }: PopoverContentProps) => {
const { t } = useTranslation();
const heading = useMemo<string | undefined>(
() => t(`popovers.${feature}.heading`),
[feature, t]
);
const paragraphs = useMemo<string[]>(
() =>
t(`popovers.${feature}.paragraphs`, {
returnObjects: true,
}) ?? [],
[feature, t]
);
const handleClick = useCallback(() => {
if (!data?.href) {
return;
}
window.open(data.href);
}, [data?.href]);
return (
<InvPopoverContent w={96}>
<InvPopoverCloseButton />
<InvPopoverBody>
<Flex gap={2} flexDirection="column" alignItems="flex-start">
{heading && (
<>
<InvHeading size="sm">{heading}</InvHeading>
<Divider />
</>
)}
{data?.image && (
<>
<Image
objectFit="contain"
maxW="60%"
maxH="60%"
backgroundColor="white"
src={data.image}
alt="Optional Image"
/>
<Divider />
</>
)}
{paragraphs.map((p) => (
<InvText key={p}>{p}</InvText>
))}
{data?.href && (
<>
<Divider />
<InvButton
pt={1}
onClick={handleClick}
leftIcon={<FaExternalLinkAlt />}
alignSelf="flex-end"
variant="link"
>
{t('common.learnMore') ?? heading}
</InvButton>
</>
)}
</Flex>
</InvPopoverBody>
</InvPopoverContent>
);
};

View File

@ -1,4 +1,4 @@
import type { PopoverProps } from '@invoke-ai/ui-library';
import type { PopoverProps } from '@chakra-ui/react';
export type Feature =
| 'clipSkip'

View File

@ -1,4 +1,4 @@
import { Badge, Flex } from '@invoke-ai/ui-library';
import { Badge, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import type { ImageDTO } from 'services/api/types';

View File

@ -1,4 +1,4 @@
import { Box, Flex, Heading } from '@invoke-ai/ui-library';
import { Box, Flex, Heading } from '@chakra-ui/react';
import type { AnimationProps } from 'framer-motion';
import { motion } from 'framer-motion';
import { memo } from 'react';

View File

@ -1,147 +0,0 @@
import {
Button,
Divider,
Flex,
Heading,
Image,
Popover,
PopoverBody,
PopoverCloseButton,
PopoverContent,
PopoverTrigger,
Portal,
Text,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { merge, omit } from 'lodash-es';
import type { ReactElement } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowSquareOutBold } from 'react-icons/pi';
import type { Feature, PopoverData } from './constants';
import { OPEN_DELAY, POPOVER_DATA, POPPER_MODIFIERS } from './constants';
type Props = {
feature: Feature;
inPortal?: boolean;
children: ReactElement;
};
export const InformationalPopover = memo(
({ feature, children, inPortal = true, ...rest }: Props) => {
const shouldEnableInformationalPopovers = useAppSelector(
(s) => s.system.shouldEnableInformationalPopovers
);
const data = useMemo(() => POPOVER_DATA[feature], [feature]);
const popoverProps = useMemo(
() => merge(omit(data, ['image', 'href', 'buttonLabel']), rest),
[data, rest]
);
if (!shouldEnableInformationalPopovers) {
return children;
}
return (
<Popover
isLazy
closeOnBlur={false}
trigger="hover"
variant="informational"
openDelay={OPEN_DELAY}
modifiers={POPPER_MODIFIERS}
placement="top"
{...popoverProps}
>
<PopoverTrigger>{children}</PopoverTrigger>
{inPortal ? (
<Portal>
<Content data={data} feature={feature} />
</Portal>
) : (
<Content data={data} feature={feature} />
)}
</Popover>
);
}
);
InformationalPopover.displayName = 'InformationalPopover';
type ContentProps = {
data?: PopoverData;
feature: Feature;
};
const Content = ({ data, feature }: ContentProps) => {
const { t } = useTranslation();
const heading = useMemo<string | undefined>(
() => t(`popovers.${feature}.heading`),
[feature, t]
);
const paragraphs = useMemo<string[]>(
() =>
t(`popovers.${feature}.paragraphs`, {
returnObjects: true,
}) ?? [],
[feature, t]
);
const handleClick = useCallback(() => {
if (!data?.href) {
return;
}
window.open(data.href);
}, [data?.href]);
return (
<PopoverContent w={96}>
<PopoverCloseButton />
<PopoverBody>
<Flex gap={2} flexDirection="column" alignItems="flex-start">
{heading && (
<>
<Heading size="sm">{heading}</Heading>
<Divider />
</>
)}
{data?.image && (
<>
<Image
objectFit="contain"
maxW="60%"
maxH="60%"
backgroundColor="white"
src={data.image}
alt="Optional Image"
/>
<Divider />
</>
)}
{paragraphs.map((p) => (
<Text key={p}>{p}</Text>
))}
{data?.href && (
<>
<Divider />
<Button
pt={1}
onClick={handleClick}
leftIcon={<PiArrowSquareOutBold />}
alignSelf="flex-end"
variant="link"
>
{t('common.learnMore') ?? heading}
</Button>
</>
)}
</Flex>
</PopoverBody>
</PopoverContent>
);
};

View File

@ -0,0 +1,70 @@
import type { Meta, StoryObj } from '@storybook/react';
import { InvText } from 'common/components/InvText/wrapper';
import { InvAccordionButton } from './InvAccordionButton';
import type { InvAccordionProps } from './types';
import { InvAccordion, InvAccordionItem, InvAccordionPanel } from './wrapper';
const meta: Meta<typeof InvAccordion> = {
title: 'Primitives/InvAccordion',
tags: ['autodocs'],
component: InvAccordion,
args: {
colorScheme: 'base',
},
};
export default meta;
type Story = StoryObj<typeof InvAccordion>;
const Component = (props: InvAccordionProps) => {
return (
<InvAccordion {...props} defaultIndex={[0]} allowMultiple>
<InvAccordionItem>
<InvAccordionButton badges={['and', 'i', 'said']}>
Section 1 title
</InvAccordionButton>
<InvAccordionPanel p={4}>
<InvText>
25 years and my life is still Tryin&apos; to get up that great big
hill of hope For a destination I realized quickly when I knew I
should That the world was made up of this brotherhood of man For
whatever that means
</InvText>
</InvAccordionPanel>
</InvAccordionItem>
<InvAccordionItem>
<InvAccordionButton badges={['heeeyyyyyy']}>
Section 1 title
</InvAccordionButton>
<InvAccordionPanel p={4}>
<InvText>
And so I cry sometimes when I&apos;m lying in bed Just to get it all
out what&apos;s in my head And I, I am feeling a little peculiar And
so I wake in the morning and I step outside And I take a deep breath
and I get real high And I scream from the top of my lungs
&quot;What&apos;s going on?&quot;
</InvText>
</InvAccordionPanel>
</InvAccordionItem>
<InvAccordionItem>
<InvAccordionButton badges={["what's", 'goin', 'on', '?']}>
Section 2 title
</InvAccordionButton>
<InvAccordionPanel p={4}>
<InvText>
And I say, hey-ey-ey Hey-ey-ey I said &quot;Hey, a-what&apos;s going
on?&quot; And I say, hey-ey-ey Hey-ey-ey I said &quot;Hey,
a-what&apos;s going on?&quot;
</InvText>
</InvAccordionPanel>
</InvAccordionItem>
</InvAccordion>
);
};
export const Default: Story = {
render: Component,
};

View File

@ -0,0 +1,31 @@
import {
AccordionButton as ChakraAccordionButton,
Spacer,
} from '@chakra-ui/react';
import { InvBadge } from 'common/components/InvBadge/wrapper';
import { truncate } from 'lodash-es';
import { useMemo } from 'react';
import type { InvAccordionButtonProps } from './types';
import { InvAccordionIcon } from './wrapper';
export const InvAccordionButton = (props: InvAccordionButtonProps) => {
const { children, badges: _badges, ...rest } = props;
const badges = useMemo<string[] | undefined>(
() =>
_badges?.map((b) => truncate(String(b), { length: 24, omission: '...' })),
[_badges]
);
return (
<ChakraAccordionButton {...rest}>
{children}
<Spacer />
{badges?.map((b, i) => (
<InvBadge key={`${b}.${i}`} colorScheme="invokeBlue">
{b}
</InvBadge>
))}
<InvAccordionIcon />
</ChakraAccordionButton>
);
};

View File

@ -0,0 +1,71 @@
import { accordionAnatomy as parts } from '@chakra-ui/anatomy';
import {
createMultiStyleConfigHelpers,
defineStyle,
} from '@chakra-ui/styled-system';
const { definePartsStyle, defineMultiStyleConfig } =
createMultiStyleConfigHelpers(parts.keys);
const invokeAIContainer = defineStyle({
border: 'none',
bg: 'base.850',
borderRadius: 'base',
':has(&div &button:hover)': { bg: 'base.800' },
transitionProperty: 'common',
transitionDuration: '0.1s',
});
const invokeAIButton = defineStyle((_props) => {
return {
gap: 2,
fontWeight: 'semibold',
fontSize: 'sm',
border: 'none',
borderRadius: 'base',
color: 'base.300',
_hover: {},
_expanded: {
borderBottomRadius: 'none',
},
};
});
const invokeAIPanel = defineStyle((props) => {
const { colorScheme: c } = props;
return {
bg: `${c}.800`,
borderRadius: 'base',
p: 0,
transitionProperty: 'common',
transitionDuration: '0.1s',
};
});
const invokeAIIcon = defineStyle({
ms: 2,
});
const invokeAI = definePartsStyle((props) => ({
container: invokeAIContainer,
button: invokeAIButton(props),
panel: invokeAIPanel(props),
icon: invokeAIIcon,
}));
const baseStyle = definePartsStyle(() => ({
root: {
display: 'flex',
flexDirection: 'column',
gap: 4,
},
}));
export const accordionTheme = defineMultiStyleConfig({
baseStyle,
variants: { invokeAI },
defaultProps: {
variant: 'invokeAI',
colorScheme: 'base',
},
});

View File

@ -0,0 +1,11 @@
import type { AccordionButtonProps as ChakraAccordionButtonProps } from '@chakra-ui/react';
export type {
AccordionIconProps as InvAccordionIconProps,
AccordionItemProps as InvAccordionItemProps,
AccordionPanelProps as InvAccordionPanelProps,
AccordionProps as InvAccordionProps,
} from '@chakra-ui/react';
export type InvAccordionButtonProps = ChakraAccordionButtonProps & {
badges?: (string | number)[];
};

View File

@ -0,0 +1,6 @@
export {
Accordion as InvAccordion,
AccordionIcon as InvAccordionIcon,
AccordionItem as InvAccordionItem,
AccordionPanel as InvAccordionPanel,
} from '@chakra-ui/react';

View File

@ -0,0 +1,4 @@
/**
* AlertDialog is a chakra Modal internally and uses those props.
*/
export type { AlertDialogProps as InvAlertDialogProps } from '@chakra-ui/react';

View File

@ -0,0 +1,9 @@
export {
AlertDialog as InvAlertDialog,
AlertDialogBody as InvAlertDialogBody,
AlertDialogCloseButton as InvAlertDialogCloseButton,
AlertDialogContent as InvAlertDialogContent,
AlertDialogFooter as InvAlertDialogFooter,
AlertDialogHeader as InvAlertDialogHeader,
AlertDialogOverlay as InvAlertDialogOverlay,
} from '@chakra-ui/react';

Some files were not shown because too many files have changed in this diff Show More