Compare commits

..

146 Commits

Author SHA1 Message Date
bbb48c5475 Gitignore aider 2024-05-28 15:42:34 -04:00
e68c49167a aider: Fixed CPU memory issues by deleting large objects and forcing garbage collection.
# Aider chat conversation:

USER: `invokeai/backend/model_manager/convert_ckpt_to_diffusers.py` Im having issues with cpu memory, not gpu memory

ASSISTANT: To address CPU memory issues, we can make the following changes:

1. Use `del` to explicitly delete large objects when they are no longer needed.
2. Use `gc.collect()` to force garbage collection after deleting large objects.

Here are the *SEARCH/REPLACE* blocks:

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    vae = AutoencoderKL(**vae_config)
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        vae.to(precision)
        torch.cuda.empty_cache()
=======
    vae = AutoencoderKL(**vae_config)
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        del converted_vae_checkpoint  # Free memory
        import gc
        gc.collect()
        vae.to(precision)
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
=======
    pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
    with torch.no_grad():
        del kwargs  # Free memory
        import gc
        gc.collect()
        pipe = pipe.to(precision)
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
=======
    pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
    with torch.no_grad():
        del kwargs  # Free memory
        import gc
        gc.collect()
        pipe = pipe.to(precision)
>>>>>>> REPLACE
```
2024-05-28 15:05:36 -04:00
46950a9bd0 aider: Improved memory usage efficiency by avoiding intermediate value storage and clearing the cache.
# Aider chat conversation:

USER: `invokeai/backend/model_manager/convert_ckpt_to_diffusers.py` Could you improve the memory usage efficiency of this code?

ASSISTANT: To improve the memory usage efficiency of the code, we can make the following changes:

1. Use `with torch.no_grad()` to avoid storing intermediate values for backpropagation.
2. Move the model to the desired device only when necessary.
3. Clear the cache after moving the model to the desired device.

Here are the *SEARCH/REPLACE* blocks:

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    vae.load_state_dict(converted_vae_checkpoint)
    vae.to(precision)
=======
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        vae.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = pipe.to(precision)
=======
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = pipe.to(precision)
=======
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```
2024-05-28 15:03:43 -04:00
df91d1b849 Update TI handling for compatibility with transformers 4.40.0 (#6449)
## Summary

- Updated the documentation for `TextualInversionManager`
- Updated the `self.tokenizer.model_max_length` access to work with the
latest transformers version. Thanks to @skunkworxdark for looking into
this here:
https://github.com/invoke-ai/InvokeAI/issues/6445#issuecomment-2133098342

## Related Issues / Discussions

Closes #6445 

## QA Instructions

I tested with `transformers==4.41.1`, and compared the results against a
recent InvokeAI version before updating tranformers - no change, as
expected.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-05-28 08:32:02 -04:00
829b9ad66b Add a callout about the hackiness of dropping tokens in the TextualInversionManager. 2024-05-28 05:11:54 -07:00
3aa1c8d3a8 Update TextualInversionManager for compatibility with the latest transformers release. See https://github.com/invoke-ai/InvokeAI/issues/6445. 2024-05-28 05:11:54 -07:00
994c61b67a Add docs to TextualInversionManager and improve types. No changes to functionality. 2024-05-28 05:11:54 -07:00
21aa42627b feat(events): add dynamic invocation & result validators
This is required to get these event fields to deserialize correctly. If omitted, pydantic uses `BaseInvocation`/`BaseInvocationOutput`, which is not correct.

This is similar to the workaround in the `Graph` and `GraphExecutionState` classes where we need to fanagle pydantic with manual validation handling.
2024-05-28 05:11:37 -07:00
a4f88ff834 feat(events): add __event_name__ as ClassVar to EventBase
This improves types for event consumers that need to access the event name.
2024-05-28 05:11:37 -07:00
ddff9b4584 fix(events): typing for download event handler 2024-05-27 11:13:47 +10:00
b50133d5e1 feat(events): register event schemas
This allows for events to be dispatched using dicts as payloads, and have the dicts validated as pydantic schemas.
2024-05-27 11:13:47 +10:00
5388f5a817 fix(ui): edit variant for main models only
Closes #6444
2024-05-27 11:02:00 +10:00
27a3eb15f8 feat(ui): update event types 2024-05-27 10:17:02 +10:00
4b2d57a5e0 chore(ui): typegen
Note about the huge diff: I had a different version of pydantic installed at some point, which slightly altered a _ton_ of schema components. This typegen was done on the correct version of pydantic and un-does those alterations, in addition to the intentional changes to event models.
2024-05-27 10:17:02 +10:00
bbb90ff949 feat(events): restore whole invocation to event payloads
Removing this is a breaking API change - some consumers of the events need the whole invocation. Didn't realize that until now.
2024-05-27 10:17:02 +10:00
9d9801b2c2 feat(events): stronger generic typing for event registration 2024-05-27 10:17:02 +10:00
8498d4344b docs: update docstrings in sockets.py 2024-05-27 09:06:02 +10:00
dfad37a262 docs: update comments & docstrings 2024-05-27 09:06:02 +10:00
89dede7bad feat(ui): simplify client sio redux actions
- Add a simple helper to create socket actions in a less error-prone way
- Organize and tidy sio files
2024-05-27 09:06:02 +10:00
60784a4361 feat(ui): update client for removal of session events 2024-05-27 09:06:02 +10:00
3d8774d295 chore(ui): typegen 2024-05-27 09:06:02 +10:00
084cf26ed6 refactor: remove all session events
There's no longer any need for session-scoped events now that we have the session queue. Session started/completed/canceled map 1-to-1 to queue item status events, but queue item status events also have an event for failed state.

We can simplify queue and processor handling substantially by removing session events and instead using queue item events.

- Remove the session-scoped events entirely.
- Remove all event handling from session queue. The processor still needs to respond to some events from the queue: `QueueClearedEvent`, `BatchEnqueuedEvent` and `QueueItemStatusChangedEvent`.
- Pass an `is_canceled` callback to the invocation context instead of the cancel event
- Update processor logic to ensure the local instance of the current queue item is synced with the instance in the database. This prevents race conditions and ensures lifecycle callback do not get stale callbacks.
- Update docstrings and comments
- Add `complete_queue_item` method to session queue service as an explicit way to mark a queue item as successfully completed. Previously, the queue listened for session complete events to do this.

Closes #6442
2024-05-27 09:06:02 +10:00
8592f5c6e1 feat(events): move event sets outside sio class
This lets the event sets be consumed programmatically.
2024-05-27 09:06:02 +10:00
368127bd25 feat(events): register_events supports single event 2024-05-27 09:06:02 +10:00
c0aabcd8ea tidy(events): use tuple index access for event payloads 2024-05-27 09:06:02 +10:00
ed6c716ddc fix(mm): emit correct event when model load complete 2024-05-27 09:06:02 +10:00
eaf67b2150 feat(ui): add logging for session events 2024-05-27 09:06:02 +10:00
575943d0ad fix(processor): move session started event to session runner 2024-05-27 09:06:02 +10:00
25d1d2b591 tidy(processor): use separate handlers for each event type
Just a bit clearer without needing `isinstance` checks.
2024-05-27 09:06:02 +10:00
39415428de chore(ui): typegen 2024-05-27 09:06:02 +10:00
64d553f72c feat(events): restore temp handling of user/project 2024-05-27 09:06:02 +10:00
5b390bb11c tests: clean up tests after events changes 2024-05-27 09:06:02 +10:00
a9f773c03c fix(mm): port changes into new model_install_common file
Some subtle changes happened between this PR's last update and now. Bring them into the file.
2024-05-27 09:06:02 +10:00
585feccf82 fix(ui): update event handling to match new types 2024-05-27 09:06:02 +10:00
cbd3b15cae chore(ui): typegen 2024-05-27 09:06:02 +10:00
cc56918453 tidy(ui): remove old unused session subscribe actions 2024-05-27 09:06:02 +10:00
f82df2661a docs: clarify comment in api_app 2024-05-27 09:06:02 +10:00
a1d68eb319 fix(ui): denoise percentage 2024-05-27 09:06:02 +10:00
8b5caa7e57 chore(ui): typegen 2024-05-27 09:06:02 +10:00
b3a051250f feat(api): sort socket event names for openapi schema
Deterministic ordering prevents extraneous, non-functional changes to the autogenerated types
2024-05-27 09:06:02 +10:00
0f733c42fc fix(events): fix denoise progress percentage
- Restore calculation of step percentage but in the backend instead of client
- Simplify signatures for denoise progress event callbacks
- Clean up `step_callback.py` (types, do not recreate constant matrix on every step, formatting)
2024-05-27 09:06:02 +10:00
ec4f10aed3 chore(ui): typegen 2024-05-27 09:06:02 +10:00
d97186dfc8 feat(events): remove payload registry, add method to get event classes
We don't need to use the payload schema registry. All our events are dispatched as pydantic models, which are already validated on instantiation.

We do want to add all events to the OpenAPI schema, and we referred to the payload schema registry for this. To get all events, add a simple helper to EventBase. This is functionally identical to using the schema registry.
2024-05-27 09:06:02 +10:00
18b4f1b72a feat(ui): add missing socket events 2024-05-27 09:06:02 +10:00
5cdf71b72f feat(events): add missing events
These events weren't being emitted via socket.io:
- DownloadCancelledEvent
- DownloadCompleteEvent
- DownloadErrorEvent
- DownloadProgressEvent
- DownloadStartedEvent
- ModelInstallDownloadsCompleteEvent
2024-05-27 09:06:02 +10:00
88a2340b95 feat(events): use builder pattern for download events 2024-05-27 09:06:02 +10:00
1be4cab2d9 fix(events): dump events with mode="json"
Ensures all model events are serializable.
2024-05-27 09:06:02 +10:00
567b87cc50 docs(events): update event docstrings 2024-05-27 09:06:02 +10:00
4756920282 tests: move fixtures import to conftest.py 2024-05-27 09:06:02 +10:00
a876675448 tests: update tests to use new events 2024-05-27 09:06:02 +10:00
655f62008f fix(mm): check for presence of invoker before emitting model load event
The model loader emits events. During testing, it doesn't have access to a fully-mocked events service, so the test fails when attempting to call a nonexistent method. There was a check for this previously, but I accidentally removed it. Restored.
2024-05-27 09:06:02 +10:00
300725d1dd fix(ui): correct model load event format 2024-05-27 09:06:02 +10:00
bf03127c69 fix(events): add missing __event_name__ to EventBase 2024-05-27 09:06:02 +10:00
2dc752ea83 feat(events): simplify event classes
- Remove ABCs, they do not work well with pydantic
- Remove the event type classvar - unused
- Remove clever logic to require an event name - we already get validation for this during schema registration.
- Rename event bases to all end in "Base"
2024-05-27 09:06:02 +10:00
1b9bbaa5a4 fix(events): emit bulk download events in correct room 2024-05-27 09:06:02 +10:00
3abc182b44 chore(ui): tidy after rebase 2024-05-27 09:06:02 +10:00
8d79ce94aa feat(ui): update UI to use new events
- Use OpenAPI schema for event payload types
- Update all event listeners
- Add missing events / remove old nonexistent events
2024-05-27 09:06:02 +10:00
975dc14579 chore(ui): typegen 2024-05-27 09:06:02 +10:00
9bd78823a3 refactor(events): use pydantic schemas for events
Our events handling and implementation has a couple pain points:
- Adding or removing data from event payloads requires changes wherever the events are dispatched from.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.
- Frontend types for socket events must be manually typed. This has caused several bugs.

`fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly.

This allows us to eliminate a layer of indirection and some unpleasant complexity:
- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed.
- Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods are provided as few args as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload.
- Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync.

This commit moves the backend over to this improved event handling setup.
2024-05-27 09:06:02 +10:00
461e857824 fix(ui): parameter not set translation 2024-05-26 08:21:06 -07:00
48db0b90e8 Bump transformers 2024-05-26 12:51:07 +10:00
c010ce49f7 Bump huggingface-hub 2024-05-26 12:51:07 +10:00
6df8b23c59 Bump transformers 2024-05-26 12:51:07 +10:00
dfe02b26c1 Bump accelerate 2024-05-26 12:51:07 +10:00
4142dc7141 Update deps to their lastest version 2024-05-26 12:51:07 +10:00
86bfcc53a3 docs: fix typo (#6395)
may noise steps -> many noise steps
2024-05-24 18:02:17 +00:00
532f82cb97 Optimize RAM to VRAM transfer (#6312)
* avoid copying model back from cuda to cpu

* handle models that don't have state dicts

* add assertions that models need a `device()` method

* do not rely on torch.nn.Module having the device() method

* apply all patches after model is on the execution device

* fix model patching in latents too

* log patched tokenizer

* closes #6375

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-05-24 17:06:09 +00:00
7437085cac fix typo (#6255) 2024-05-24 15:26:05 +00:00
e9b80cf28f fix(ui): isLocal erroneously hardcoded 2024-05-25 00:05:44 +10:00
f5a775ae4e feat(ui): toast on queue item errors, improved error descriptions
Show error toasts on queue item error events instead of invocation error events. This allows errors that occurred outside node execution to be surfaced to the user.

The error description component is updated to show the new error message if available. Commercial handling is retained, but local now uses the same component to display the error message itself.
2024-05-24 20:02:24 +10:00
50dd569411 fix(processor): race condition that could result in node errors not getting reported
I had set the cancel event at some point during troubleshooting an unrelated issue. It seemed logical that it should be set there, and didn't seem to break anything. However, this is not correct.

The cancel event should not be set in response to a queue status change event. Doing so can cause a race condition when nodes are executed very quickly.

It's possible that a previously-executed session's queue item status change event is handled after the next session starts executing. The cancel event is set and the session runner sees it aborting the session run early.

In hindsight, it doesn't make sense to set the cancel event here either. It should be set in response to user action, e.g. the user cancelled the session or cleared the queue (which implicitly cancels the current session). These events actually trigger the queue item status changed event, so if we set the cancel event here, we'd be setting it twice per cancellation.
2024-05-24 20:02:24 +10:00
125e1d7eb4 tidy: remove unnecessary whitespace changes 2024-05-24 20:02:24 +10:00
2fbe5ecb00 fix(ui): correctly fallback to error message when traceback is empty string 2024-05-24 20:02:24 +10:00
ba4d27860f tidy(ui): remove extraneous condition in socketInvocationError 2024-05-24 20:02:24 +10:00
6fc7614b4a fix(ui): race condition with progress
There's a race condition where a canceled session may emit a progress event or two after it's been canceled, and the progress image isn't cleared out.

To resolve this, the system slice tracks canceled session ids. When a progress event comes in, we check the cancellations and skip setting the progress if canceled.
2024-05-24 20:02:24 +10:00
9c926f249f feat(processor): add debug log stmts to session running callbacks 2024-05-24 20:02:24 +10:00
80faeac913 fix(processor): fix race condition related to clearing the queue 2024-05-24 20:02:24 +10:00
418c932595 tidy(processor): remove test callbacks 2024-05-24 20:02:24 +10:00
9117db2673 tidy(queue): delete unused delete_queue_item method 2024-05-24 20:02:24 +10:00
4a48aa98a4 chore: ruff 2024-05-24 20:02:24 +10:00
e365d35c93 docs(processor): update docstrings, comments 2024-05-24 20:02:24 +10:00
aa329ea811 feat(ui): handle enriched events 2024-05-24 20:02:24 +10:00
1e622a5706 chore(ui): typegen 2024-05-24 20:02:24 +10:00
ae66d32b28 feat(app): update test event callbacks 2024-05-24 20:02:24 +10:00
2dd3a85ade feat(processor): update enriched errors & fail_queue_item() 2024-05-24 20:02:24 +10:00
a8492bd7e4 feat(events): add enriched errors to events 2024-05-24 20:02:24 +10:00
25954ea750 feat(queue): session queue error handling
- Add handling for new error columns `error_type`, `error_message`, `error_traceback`.
- Update queue item model to include the new data. The `error_traceback` field has an alias of `error` for backwards compatibility.
- Add `fail_queue_item` method. This was previously handled by `cancel_queue_item`. Splitting this functionality makes failing a queue item a bit more explicit. We also don't need to handle multiple optional error args.
-
2024-05-24 20:02:24 +10:00
887b73aece feat(db): add error_type, error_message, rename error -> error_traceback to session_queue table 2024-05-24 20:02:24 +10:00
3c41c67d13 fix(processor): restore missing update of session 2024-05-24 20:02:24 +10:00
6c79be7dc3 chore: ruff 2024-05-24 20:02:24 +10:00
097619ef51 feat(processor): get user/project from queue item w/ fallback 2024-05-24 20:02:24 +10:00
a1f7a9cd6f fix(app): fix logging of error classes instead of class names 2024-05-24 20:02:24 +10:00
25b9c19eed feat(app): handle preparation errors as node errors
We were not handling node preparation errors as node errors before. Here's the explanation, copied from a comment that is no longer required:

---

TODO(psyche): Sessions only support errors on nodes, not on the session itself. When an error occurs outside
node execution, it bubbles up to the processor where it is treated as a queue item error.

Nodes are pydantic models. When we prepare a node in `session.next()`, we set its inputs. This can cause a
pydantic validation error. For example, consider a resize image node which has a constraint on its `width`
input field - it must be greater than zero. During preparation, if the width is set to zero, pydantic will
raise a validation error.

When this happens, it breaks the flow before `invocation` is set. We can't set an error on the invocation
because we didn't get far enough to get it - we don't know its id. Hence, we just set it as a queue item error.

---

This change wraps the node preparation step with exception handling. A new `NodeInputError` exception is raised when there is a validation error. This error has the node (in the state it was in just prior to the error) and an identifier of the input that failed.

This allows us to mark the node that failed preparation as errored, correctly making such errors _node_ errors and not _processor_ errors. It's much easier to diagnose these situations. The error messages look like this:

> Node b5ac87c6-0678-4b8c-96b9-d215aee12175 has invalid incoming input for height

Some of the exception handling logic is cleaned up.
2024-05-24 20:02:24 +10:00
cc2d877699 docs(app): explain why errors are handled poorly 2024-05-24 20:02:24 +10:00
be82404759 tidy(app): "outputs" -> "output" 2024-05-24 20:02:24 +10:00
33f9fe2c86 tidy(app): rearrange proccessor 2024-05-24 20:02:24 +10:00
1d973f92ff feat(app): support multiple processor lifecycle callbacks 2024-05-24 20:02:24 +10:00
7f70cde038 feat(app): make things in session runner private 2024-05-24 20:02:24 +10:00
47722528a3 feat(app): iterate on processor split 2
- Use protocol to define callbacks, this allows them to have kwargs
- Shuffle the profiler around a bit
- Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice
2024-05-24 20:02:24 +10:00
be41c84305 feat(app): iterate on processor split
- Add `OnNodeError` and `OnNonFatalProcessorError` callbacks
- Move all session/node callbacks to `SessionRunner` - this ensures we dump perf stats before resetting them and generally makes sense to me
- Remove `complete` event from `SessionRunner`, it's essentially the same as `OnAfterRunSession`
- Remove extraneous `next_invocation` block, which would treat a processor error as a node error
- Simplify loops
- Add some callbacks for testing, to be removed before merge
2024-05-24 20:02:24 +10:00
82b4298b03 Fix next node calling logic 2024-05-24 20:02:24 +10:00
fa6c7badd6 Run ruff 2024-05-24 20:02:24 +10:00
45d2504c1e Break apart session processor and the running of each session into separate classes 2024-05-24 20:02:24 +10:00
f1bb7e86c0 feat(ui): invalidate cache for queue item on status change
This query is only subscribed-to in the `QueueItemDetail` component - when is rendered only when the user clicks on a queue item in the queue. Invalidating this tag instead of optimistically updating it won't cause any meaningful change to network traffic.
2024-05-24 08:59:49 +10:00
93e4c3dbc2 feat(app): update queue item's session on session completion
The session is never updated in the queue after it is first enqueued. As a result, the queue detail view in the frontend never never updates and the session itself doesn't show outputs, execution graph, etc.

We need a new method on the queue service to update a queue item's session, then call it before updating the queue item's status.

Queue item status may be updated via a session-type event _or_ queue-type event. Adding the updated session to all these events is a hairy - simpler to just update the session before we do anything that could trigger a queue item status change event:
- Before calling `emit_session_complete` in the processor (handles session error, completed and cancel events and the corresponding queue events)
- Before calling `cancel_queue_item` in the processor (handles another way queue items can be canceled, outside the session execution loop)

When serializing the session, both in the new service method and the `get_queue_item` endpoint, we need to use `exclude_none=True` to prevent unexpected validation errors.
2024-05-24 08:59:49 +10:00
c3f28f7a35 translationBot(ui): update translation (Spanish)
Currently translated at 30.5% (380 of 1243 strings)

Co-authored-by: gallegonovato <fran-carro@hotmail.es>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translation: InvokeAI/Web UI
2024-05-24 08:05:45 +10:00
c900a63842 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-05-24 08:05:45 +10:00
4eb5f004e6 Update invokeai_version.py 2024-05-24 08:00:03 +10:00
bcae735d7c fix(ui): initial image layers always ignored (#6434)
## Summary

Whoops!

## Related Issues / Discussions


https://discord.com/channels/1020123559063990373/1049495067846524939/1243186572115837009

## QA Instructions

- Generate w/ initial image layer

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-05-24 03:16:18 +05:30
861f06c459 Merge branch 'main' into psyche/fix/ui/initial-image-layer 2024-05-24 03:14:18 +05:30
c493628272 fix(ui): 'undefined' being used for metadata on uploaded images (#6433)
## Summary

TIL if you add `undefined` to a form data object, it gets stringified to
`'undefined'`. Whoops!

## Related Issues / Discussions

n/a

## QA Instructions

n/a

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-05-24 03:14:02 +05:30
46a90ca402 fix(ui): initial image layers always ignored
Whoops!
2024-05-24 06:40:48 +10:00
d45c33b446 fix(ui): 'undefined' being used for metadata on uploaded images 2024-05-24 06:17:07 +10:00
88025d32c2 feat(api): downgrade metadata parse warnings to debug
I set these to warn during testing and neglected to undo the change.
2024-05-23 22:48:34 +10:00
af64764082 fix: remove db maintenance script from launcher
It is broken.
2024-05-23 22:39:55 +10:00
70487f0c2e fix(ui): layers are "enabled", not "visible" 2024-05-23 10:14:34 +10:00
55d7d9cc75 fix(ui): control layers don't disable correctly
Closes #6424
2024-05-23 10:14:34 +10:00
106674175c add logo and change text for non-local; 2024-05-23 06:51:13 +10:00
dd1d5bdb25 use support URL for non-local 2024-05-23 06:51:13 +10:00
6259ac0bec translationBot(ui): update translation (Dutch)
Currently translated at 79.6% (973 of 1222 strings)

Co-authored-by: Dennis <dennis@vanzoerlandt.nl>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/nl/
Translation: InvokeAI/Web UI
2024-05-22 09:51:12 +10:00
ba31f8a9a9 translationBot(ui): update translation (Italian)
Currently translated at 98.5% (1210 of 1228 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1206 of 1224 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1204 of 1222 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-05-22 09:51:12 +10:00
0ba57d6dc5 feat(ui): close starter models toast when a model is installed 2024-05-22 09:40:46 +10:00
abc133e936 feat(ui): revised invocation error toast handling
Only display the session if local. Otherwise, just display the error message.
2024-05-22 09:40:46 +10:00
57743239d7 feat(ui): add updateDescription flag to toast API
If false, when updating a toast, the description is left alone. The count will still tick up.
2024-05-22 09:40:46 +10:00
4a394c60cf feat(ui): add isLocal flag to config 2024-05-22 09:40:46 +10:00
624d28a93d feat(ui): invocation error toasts do not autoclose 2024-05-22 09:40:46 +10:00
29e1ea59fc feat(ui): style copy button on ToastWithSessionRefDescription 2024-05-22 09:40:46 +10:00
2e5d24f272 tidy(ui): remove old comment 2024-05-22 09:40:46 +10:00
1afa340b1a fix(ui): show toast when recalling seed 2024-05-22 09:40:46 +10:00
3b381b5a8c tidy(ui): remove the ToastID enum
With the model install logic cleaned up the enum is less useful
2024-05-22 09:40:46 +10:00
f2b9684de8 tidy(ui): split install model into helper hook
This was duplicated like 7 times or so
2024-05-22 09:40:46 +10:00
a66b3497e0 feat(ui): port all toasts to use new util 2024-05-22 09:40:46 +10:00
683ec8e5f2 feat(ui): add stateful toast utility
Small wrapper around chakra's toast system simplifies creating and updating toasts. See comments in toast.ts for details.
2024-05-22 09:40:46 +10:00
f31f0cf733 feat(ui): restore spellcheck on prompt boxes 2024-05-22 08:52:25 +10:00
38265b3123 docs(ui): update validateWorkflow comments 2024-05-21 05:17:10 -07:00
caca28286c tests(ui): add test for resource usage check 2024-05-21 05:17:10 -07:00
38320a5100 feat(ui): reset missing images, boards and models when loading workflows
These fields are reset back to `undefined` if not accessible. A warning toast is showing, and in the JS console, the full warning message is logged.
2024-05-21 05:17:10 -07:00
7badaab17d docs: fix link to invoke ai models site 2024-05-20 20:48:42 -07:00
aa0c59bb51 fix(ui): crash when using notes nodes or missing node/field templates (#6412)
## Summary

Notes nodes used some overly-strict redux selectors. The selectors are
now more chill. Also fixed an issue where you couldn't edit a notes node
title.

Found another class of error related to the overly strict reducers that
caused errors when loading a workflow that had missing templates. Fixed
this with fallback wrapper component, works like an error boundary when
a template isn't found.

## Related Issues / Discussions


https://discord.com/channels/1020123559063990373/1149506274971631688/1242256425527545949

## QA Instructions

- Add a notes node to a workflow. Edit the notes title.
- Load a workflow that has nodes that aren't installed. Should get a
fallback UI for each missing node.
- Load a workflow that references a node with different inputs than are
in the template - like an old version of a node. Should get a fallback
field warning for both missing templates, or missing inputs.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-05-21 07:59:43 +05:30
e4acaa5c8f chore: v4.2.2post1 2024-05-21 11:31:06 +10:00
9ba47cae20 fix(ui): unable to edit notes node title 2024-05-21 11:27:11 +10:00
bf4310ca71 fix(ui): errors when node template or field template doesn't exist
Some asserts were bubbling up in places where they shouldn't have, causing errors when a node has a field without a matching template, or vice-versa.

To resolve this without sacrificing the runtime safety provided by asserts, a `InvocationFieldCheck` component now wraps all field components. This component renders a fallback when a field doesn't exist, so the inner components can safely use the asserts.
2024-05-21 11:22:08 +10:00
e75f98317f fix(ui): notes node text not selectable 2024-05-21 10:06:25 +10:00
1249d4a6e3 fix(ui): crash when using a notes node 2024-05-21 10:06:09 +10:00
66c9f4708d Update invokeai_version.py 2024-05-21 07:11:09 +10:00
32277193b6 fix(ui): retain denoise strength and opacity when changing image 2024-05-20 18:27:51 +10:00
159 changed files with 3056 additions and 2987 deletions

1
.gitignore vendored
View File

@ -188,3 +188,4 @@ installer/install.sh
installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/
.aider*

View File

@ -64,7 +64,7 @@ GPU_DRIVER=nvidia
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even Moar Customizing!
## Even More Customizing!
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.

View File

@ -165,7 +165,7 @@ Additionally, each section can be expanded with the "Show Advanced" button in o
There are several ways to install IP-Adapter models with an existing InvokeAI installation:
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models.
2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](https://www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
#### Using IP-Adapter

View File

@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
4. The VAE decodes the final latent image from latent space into image space.
Image-to-image is a similar process, with only step 1 being different:
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).

View File

@ -10,8 +10,7 @@ set INVOKEAI_ROOT=.
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Open the developer console
echo 3. Run the InvokeAI image database maintenance script
echo 4. Command-line help
echo 3. Command-line help
echo Q - Quit
echo.
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.
@ -34,9 +33,6 @@ IF /I "%choice%" == "1" (
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) ELSE IF /I "%choice%" == "3" (
echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "4" (
echo Displaying command line help...
python .venv\Scripts\invokeai-web.exe --help %*
pause

View File

@ -47,11 +47,6 @@ do_choice() {
bash --init-file "$file_name"
;;
3)
clear
printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
;;
4)
clear
printf "Command-line help\n"
invokeai-web --help
@ -71,8 +66,7 @@ do_line_input() {
printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n"
printf "2: Open the developer console\n"
printf "3: Run the InvokeAI image database maintenance script\n"
printf "4: Command-line help\n"
printf "3: Command-line help\n"
printf "Q: Quit\n\n"
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n"
read -p "Please enter 1-4, Q: [1] " yn

View File

@ -30,7 +30,7 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
@ -103,7 +103,7 @@ class ApiDependencies:
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)

View File

@ -69,7 +69,7 @@ async def upload_image(
if isinstance(metadata_raw, str):
_metadata = metadata_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
pass
# attempt to parse workflow from image
@ -77,7 +77,7 @@ async def upload_image(
if isinstance(workflow_raw, str):
_workflow = workflow_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
pass
# attempt to extract graph from image
@ -85,7 +85,7 @@ async def upload_image(
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image")
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
pass
try:

View File

@ -15,6 +15,7 @@ from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadEventBase,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
@ -34,21 +35,53 @@ from invokeai.app.services.events.events_common import (
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io queue room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io bulk downloads room.
This is a pydantic model to ensure the data is in the correct format."""
bulk_download_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
}
MODEL_EVENTS = {
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
class SocketIO:
_sub_queue = "subscribe_queue"
_unsub_queue = "unsubscribe_queue"
@ -66,45 +99,9 @@ class SocketIO:
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
register_events(
{
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
SessionStartedEvent,
SessionCompleteEvent,
SessionCanceledEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
},
self._handle_queue_event,
)
register_events(
{
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
},
self._handle_model_event,
)
register_events(
{BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
self._handle_bulk_image_download_event,
)
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
@ -119,13 +116,10 @@ class SocketIO:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.bulk_download_id)
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)

View File

@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation):
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
# apply all patches while the model is on the target device
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
):
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
tokenizer=tokenizer,
tokenizer=patched_tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer)
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -136,11 +134,7 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@ -177,20 +171,23 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
# apply all patches while the model is on the target device
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=tokenizer,
tokenizer=patched_tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -203,7 +200,7 @@ class SDXLPromptInvocationBase:
if context.config.get().log_tokenization:
# TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, tokenizer)
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

View File

@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet,
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
):

View File

@ -14,7 +14,6 @@ from invokeai.app.services.events.events_common import (
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
ExtraData,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
@ -29,9 +28,6 @@ from invokeai.app.services.events.events_common import (
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@ -58,11 +54,9 @@ class EventServiceBase:
# region: Invocation
def emit_invocation_started(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None
) -> None:
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra))
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
def emit_invocation_denoise_progress(
self,
@ -70,184 +64,132 @@ class EventServiceBase:
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(
InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra)
)
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
def emit_invocation_complete(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
output: "BaseInvocationOutput",
extra: Optional[ExtraData] = None,
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
) -> None:
"""Emitted when an invocation is complete"""
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra))
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
def emit_invocation_error(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
error_message: str,
error_traceback: str,
) -> None:
"""Emitted when an invocation encounters an error"""
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra))
# endregion
# region Session
def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has started"""
self.dispatch(SessionStartedEvent.build(queue_item, extra))
def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has completed all invocations"""
self.dispatch(SessionCompleteEvent.build(queue_item, extra))
def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session is canceled"""
self.dispatch(SessionCanceledEvent.build(queue_item, extra))
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
# endregion
# region Queue
def emit_queue_item_status_changed(
self,
queue_item: "SessionQueueItem",
batch_status: "BatchStatus",
queue_status: "SessionQueueStatus",
extra: Optional[ExtraData] = None,
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
) -> None:
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra))
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None:
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra))
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None:
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id, extra))
self.dispatch(QueueClearedEvent.build(queue_id))
# endregion
# region Download
def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
def emit_download_started(self, job: "DownloadJob") -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(job, extra))
self.dispatch(DownloadStartedEvent.build(job))
def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
def emit_download_progress(self, job: "DownloadJob") -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(job, extra))
self.dispatch(DownloadProgressEvent.build(job))
def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
def emit_download_complete(self, job: "DownloadJob") -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(job, extra))
self.dispatch(DownloadCompleteEvent.build(job))
def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
def emit_download_cancelled(self, job: "DownloadJob") -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(job, extra))
self.dispatch(DownloadCancelledEvent.build(job))
def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
def emit_download_error(self, job: "DownloadJob") -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(job, extra))
self.dispatch(DownloadErrorEvent.build(job))
# endregion
# region Model loading
def emit_model_load_started(
self,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
) -> None:
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra))
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
def emit_model_load_complete(
self,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra))
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
# endregion
# region Model install
def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra))
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra))
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
"""Emitted once when an install job is started (after any download)."""
self.dispatch(ModelInstallStartedEvent.build(job, extra))
self.dispatch(ModelInstallStartedEvent.build(job))
def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is completed successfully."""
self.dispatch(ModelInstallCompleteEvent.build(job, extra))
self.dispatch(ModelInstallCompleteEvent.build(job))
def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is cancelled."""
self.dispatch(ModelInstallCancelledEvent.build(job, extra))
self.dispatch(ModelInstallCancelledEvent.build(job))
def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job encounters an exception."""
self.dispatch(ModelInstallErrorEvent.build(job, extra))
self.dispatch(ModelInstallErrorEvent.build(job))
# endregion
# region Bulk image download
def emit_bulk_download_started(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is started"""
self.dispatch(
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
)
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_complete(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is complete"""
self.dispatch(
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
)
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_error(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra)
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
)
# endregion

View File

@ -1,8 +1,9 @@
from math import floor
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
@ -22,9 +23,6 @@ if TYPE_CHECKING:
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
ExtraData: TypeAlias = dict[str, Any]
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
@ -35,8 +33,8 @@ class EventBase(BaseModel):
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event")
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@ -54,7 +52,7 @@ class EventBase(BaseModel):
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase)
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
@ -63,16 +61,17 @@ Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol):
def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ...
class FastAPIEventFunc(Protocol, Generic[TEvent]):
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
"""Register a function to handle a list of events.
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
"""Register a function to handle specific events.
:param events: A list of event classes to handle
:param events: An event or set of events to handle
:param func: The function to handle the events
"""
events = events if isinstance(events, set) else {events}
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
@ -91,45 +90,45 @@ class QueueItemEventBase(QueueEventBase):
batch_id: str = Field(description="The ID of the queue batch")
class SessionEventBase(QueueItemEventBase):
"""Base class for session (aka graph execution state) events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
class InvocationEventBase(SessionEventBase):
class InvocationEventBase(QueueItemEventBase):
"""Base class for invocation events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation_id: str = Field(description="The ID of the invocation")
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
invocation_type: str = Field(description="The type of invocation")
@field_validator("invocation", mode="plain")
@classmethod
def validate_invocation(cls, v: Any):
"""Validates the invocation using the dynamic type adapter."""
invocation = BaseInvocation.get_typeadapter().validate_python(v)
return invocation
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None
) -> "InvocationStartedEvent":
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
extra=extra,
)
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
@ -148,7 +147,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
invocation: BaseInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
extra: Optional[ExtraData] = None,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
@ -158,15 +156,13 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
extra=extra,
)
@staticmethod
@ -180,6 +176,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
@ -187,34 +184,40 @@ class InvocationCompleteEvent(InvocationEventBase):
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
@field_validator("result", mode="plain")
@classmethod
def validate_results(cls, v: Any):
"""Validates the invocation result using the dynamic type adapter."""
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
return result
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
result: BaseInvocationOutput,
extra: Optional[ExtraData] = None,
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
result=result,
extra=extra,
)
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
error_type: str = Field(description="The error type")
error_message: str = Field(description="The error message")
error_traceback: str = Field(description="The error traceback")
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
@classmethod
def build(
@ -222,109 +225,65 @@ class InvocationErrorEvent(InvocationEventBase):
queue_item: SessionQueueItem,
invocation: BaseInvocation,
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
error_message: str,
error_traceback: str,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
error_type=error_type,
error=error,
extra=extra,
)
class SessionStartedEvent(SessionEventBase):
"""Event model for session_started"""
__event_name__ = "session_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class SessionCompleteEvent(SessionEventBase):
"""Event model for session_complete"""
__event_name__ = "session_complete"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class SessionCanceledEvent(SessionEventBase):
"""Event model for session_canceled"""
__event_name__ = "session_canceled"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error: Optional[str] = Field(default=None, description="The error message, if any")
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
extra: Optional[ExtraData] = None,
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
status=queue_item.status,
error=queue_item.error,
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
extra=extra,
)
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
@ -338,28 +297,25 @@ class BatchEnqueuedEvent(QueueEventBase):
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent":
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
extra=extra,
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent":
return cls(
queue_id=queue_id,
extra=extra,
)
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
class DownloadEventBase(EventBase):
@ -368,6 +324,7 @@ class DownloadEventBase(EventBase):
source: str = Field(description="The source of the download")
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
@ -376,15 +333,12 @@ class DownloadStartedEvent(DownloadEventBase):
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent":
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
extra=extra,
)
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
@ -395,17 +349,17 @@ class DownloadProgressEvent(DownloadEventBase):
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent":
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
@ -415,29 +369,23 @@ class DownloadCompleteEvent(DownloadEventBase):
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent":
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
total_bytes=job.total_bytes,
extra=extra,
)
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent":
return cls(
source=str(job.source),
extra=extra,
)
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
@ -447,21 +395,17 @@ class DownloadErrorEvent(DownloadEventBase):
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent":
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
@ -471,16 +415,11 @@ class ModelLoadStartedEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadStartedEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
@ -490,16 +429,11 @@ class ModelLoadCompleteEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadCompleteEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
@ -515,7 +449,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
)
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent":
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
@ -532,10 +466,10 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
@ -545,14 +479,11 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
@ -562,14 +493,11 @@ class ModelInstallStartedEvent(ModelEventBase):
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
@ -581,17 +509,12 @@ class ModelInstallCompleteEvent(ModelEventBase):
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent":
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(
id=job.id,
source=str(job.source),
key=(job.config_out.key),
total_bytes=job.total_bytes,
extra=extra,
)
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
@ -601,14 +524,11 @@ class ModelInstallCancelledEvent(ModelEventBase):
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
@ -620,16 +540,10 @@ class ModelInstallErrorEvent(ModelEventBase):
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent":
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(
id=job.id,
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEventBase(EventBase):
@ -640,6 +554,7 @@ class BulkDownloadEventBase(EventBase):
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
@ -647,20 +562,16 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
@ -668,20 +579,16 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
@ -691,17 +598,11 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
extra=extra,
)

View File

@ -36,6 +36,7 @@ class FastAPIEventService(EventServiceBase):
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:

View File

@ -72,6 +72,6 @@ class ModelLoadService(ModelLoadServiceBase):
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
return loaded_model

View File

@ -1,6 +1,49 @@
from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
"""Starts the session runner.
Args:
services: The invocation services.
cancel_event: The cancel event.
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
stats will be still be recorded and logged when profiling is disabled.
"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs a session.
Args:
queue_item: The session to run.
"""
pass
@abstractmethod
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Run a single node in the graph.
Args:
invocation: The invocation to run.
queue_item: The session queue item.
"""
pass
class SessionProcessorBase(ABC):
@ -26,3 +69,85 @@ class SessionProcessorBase(ABC):
def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor"""
pass
class OnBeforeRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that will be executed.
queue_item: The session queue item.
"""
...
class OnAfterRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that was executed.
queue_item: The session queue item.
"""
...
class OnNodeError(Protocol):
def __call__(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a node has an error.
Args:
invocation: The invocation that errored.
queue_item: The session queue item.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...
class OnBeforeRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnAfterRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run after executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnNonFatalProcessorError(Protocol):
def __call__(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a non-fatal error occurs in the processor.
Args:
queue_item: The session queue item, if one was being executed when the error occurred.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...

View File

@ -4,29 +4,325 @@ from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
register_events,
)
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_base import (
OnAfterRunNode,
OnAfterRunSession,
OnBeforeRunNode,
OnBeforeRunSession,
OnNodeError,
OnNonFatalProcessorError,
)
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations."""
def __init__(
self,
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
):
"""
Args:
on_before_run_session_callbacks: Callbacks to run before the session starts.
on_before_run_node_callbacks: Callbacks to run before each node starts.
on_after_run_node_callbacks: Callbacks to run after each node completes.
on_node_error_callbacks: Callbacks to run when a node errors.
on_after_run_session_callbacks: Callbacks to run after the session completes.
"""
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
def _is_canceled(self) -> bool:
"""Check if the cancel event is set. This is also passed to the invocation context builder and called during
denoising to check if the session has been canceled."""
return self._cancel_event.is_set()
def run(self, queue_item: SessionQueueItem):
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
self._on_before_run_session(queue_item=queue_item)
# Loop over invocations until the session is complete or canceled
while True:
try:
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=e.node,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
break
if invocation is None or self._is_canceled():
break
self.run_node(invocation, queue_item)
# The session is complete if all invocations have been run or there is an error on the session.
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
# use the cancel event to check if the session is canceled.
if (
queue_item.session.is_complete()
or self._is_canceled()
or queue_item.status in ["failed", "canceled", "completed"]
):
break
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
context = build_invocation_context(
data=data,
services=self._services,
is_canceled=self._is_canceled,
)
# Invoke the node
output = invocation.invoke_internal(context=context, services=self._services)
# Save output and history
queue_item.session.complete(invocation.id, output)
self._on_after_run_node(invocation, queue_item, output)
except KeyboardInterrupt:
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
pass
except CanceledException:
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
# correctly in the next iteration of the session runner loop.
#
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
# handle cancellation.
pass
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called before a session is run.
- Start the profiler if profiling is enabled.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id)
for callback in self._on_before_run_session_callbacks:
callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called after a session is run.
- Stop the profiler if profiling is enabled.
- Update the queue item's session object in the database.
- If not already canceled or failed, complete the queue item.
- Log and reset performance statistics.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler is not None:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
try:
# Update the queue item with the completed session. If the queue item has been removed from the queue,
# we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared
# while the session is running.
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# The queue item may have been canceled or failed while the session was running. We should only complete it
# if it is not already canceled or failed.
if queue_item.status not in ["canceled", "failed"]:
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
except SessionQueueItemNotFoundError:
pass
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Called before a node is run.
- Emits an invocation started event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send starting event
self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
for callback in self._on_before_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item)
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
):
"""Called after a node is run.
- Emits an invocation complete event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send complete event on successful runs
self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
for callback in self._on_after_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item, output=output)
def _on_node_error(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
):
"""Called when a node errors. Node errors may occur when running or preparing the node..
- Set the node error on the session object.
- Log the error.
- Fail the queue item.
- Emits an invocation error event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
node_error = f"{error_type}: {error_message}"
queue_item.session.set_node_error(invocation.id, node_error)
self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
)
self._services.logger.error(error_traceback)
# Fail the queue item
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
queue_item = self._services.session_queue.fail_queue_item(
queue_item.item_id, error_type, error_message, error_traceback
)
# Send error event
self._services.events.emit_invocation_error(
queue_item=queue_item,
invocation=invocation,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_node_error_callbacks:
callback(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
def __init__(
self,
session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
thread_limit: int = 1,
polling_interval: int = 1,
) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
self._thread_limit = thread_limit
self._polling_interval = polling_interval
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
@ -36,9 +332,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
register_events(QueueClearedEvent, self._on_queue_cleared)
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session.
@ -52,8 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
register_events({SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent}, self._on_queue_event)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
self._thread = Thread(
name="session_processor",
target=self._process,
@ -72,25 +369,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None:
self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
_event_name, payload = event
if (
isinstance(payload, SessionCanceledEvent)
and self._queue_item
and self._queue_item.item_id == payload.item_id
):
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
self._cancel_event.set()
self._poll_now()
elif (
isinstance(payload, QueueClearedEvent)
and self._queue_item
and self._queue_item.queue_id == payload.queue_id
):
self._cancel_event.set()
self._poll_now()
elif isinstance(payload, BatchEnqueuedEvent):
self._poll_now()
elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ("completed", "failed", "canceled"):
async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
#
# Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
# node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
# is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
if event[1].status == "canceled":
self._cancel_event.set()
self._poll_now()
def resume(self) -> SessionProcessorStatus:
@ -116,8 +413,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
resume_event: ThreadEvent,
cancel_event: ThreadEvent,
):
# Outermost processor try block; any unhandled exception is a fatal processor error
try:
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
self._thread_semaphore.acquire()
stop_event.clear()
resume_event.set()
@ -125,8 +422,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
while not stop_event.is_set():
poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try:
# Any unhandled exception in this block is a nonfatal processor error and will be handled.
# If we are paused, wait for resume event
resume_event.wait()
@ -139,140 +436,72 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
self._invoker.services.events.emit_session_started(self._queue_item)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Run the graph
self.session_runner.run(queue_item=self._queue_item)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
self._invoker.services.events.emit_invocation_started(self._queue_item, self._invocation)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
self._invoker.services.events.emit_invocation_complete(
self._queue_item, self._invocation, outputs
)
except KeyboardInterrupt:
# TODO(MM2): I don't think this is ever raised...
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
self._invoker.services.events.emit_invocation_error(
queue_item=self._queue_item,
invocation=self._invocation,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
self._invoker.services.session_queue.set_queue_item_session(
self._queue_item.item_id, self._queue_item.session
)
self._invoker.services.events.emit_session_complete(self._queue_item)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
else:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_non_fatal_processor_error(
queue_item=self._queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.set_queue_item_session(
self._queue_item.item_id, self._queue_item.session
)
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)
# Reset the invocation to None to prepare for the next session
self._invocation = None
# Immediately poll for next queue item
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
except Exception:
except Exception as e:
# Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
pass
finally:
stop_event.clear()
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Called when a non-fatal error occurs in the processor.
- Log the error.
- If a queue item is provided, update the queue item with the completed session & fail it.
- Run any callbacks registered for this event.
"""
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
if queue_item is not None:
# Update the queue item with the completed session & fail it
queue_item = self._invoker.services.session_queue.set_queue_item_session(
queue_item.item_id, queue_item.session
)
queue_item = self._invoker.services.session_queue.fail_queue_item(
item_id=queue_item.item_id,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)

View File

@ -74,10 +74,22 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
"""Completes a session queue item"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item"""
pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
) -> SessionQueueItem:
"""Fails a session queue item"""
pass
@abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs"""

View File

@ -3,7 +3,16 @@ import json
from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
StrictStr,
TypeAdapter,
field_validator,
model_validator,
)
from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation
@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel):
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
error_traceback: Optional[str] = Field(
default=None,
description="The error traceback if this queue item errored",
validation_alias=AliasChoices("error_traceback", "error"),
)
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")

View File

@ -2,13 +2,6 @@ import sqlite3
import threading
from typing import Optional, Union, cast
from invokeai.app.services.events.events_common import (
FastAPIEvent,
InvocationErrorEvent,
SessionCanceledEvent,
SessionCompleteEvent,
register_events,
)
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
@ -46,10 +39,6 @@ class SqliteSessionQueue(SessionQueueBase):
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
register_events(events={InvocationErrorEvent}, func=self._handle_error_event)
register_events(events={SessionCompleteEvent}, func=self._handle_complete_event)
register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@ -59,36 +48,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None:
try:
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
_event_name, payload = event
queue_item = self.get_queue_item(payload.item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
self._set_queue_item_status(item_id=payload.item_id, status="completed")
except SessionQueueItemNotFoundError:
pass
async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None:
try:
_event_name, payload = event
# always set to failed if have an error, even if previously the item was marked completed or canceled
self._set_queue_item_status(item_id=payload.item_id, status="failed", error=payload.error)
except SessionQueueItemNotFoundError:
pass
async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
try:
_event_name, payload = event
queue_item = self.get_queue_item(payload.item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
self._set_queue_item_status(item_id=payload.item_id, status="canceled")
except SessionQueueItemNotFoundError:
pass
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
@ -263,17 +222,22 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result))
def _set_queue_item_status(
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
self,
item_id: int,
status: QUEUE_ITEM_STATUS,
error_type: Optional[str] = None,
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error = ?
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
WHERE item_id = ?
""",
(status, error, item_id),
(status, error_type, error_message, error_traceback, item_id),
)
self.__conn.commit()
except Exception:
@ -326,26 +290,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return IsFullResult(is_full=is_full)
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id=item_id)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
DELETE FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return queue_item
def clear(self, queue_id: str) -> ClearResult:
try:
self.__lock.acquire()
@ -412,12 +356,28 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
self.__invoker.services.events.emit_session_canceled(queue_item)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
return queue_item
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
return queue_item
def fail_queue_item(
self,
item_id: int,
error_type: str,
error_message: str,
error_traceback: str,
) -> SessionQueueItem:
queue_item = self._set_queue_item_status(
item_id=item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
@ -453,7 +413,6 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled(current_queue_item)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
@ -497,7 +456,6 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled(current_queue_item)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
@ -570,7 +528,9 @@ class SqliteSessionQueue(SessionQueueBase):
status,
priority,
field_values,
error,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,

View File

@ -8,6 +8,7 @@ import networkx as nx
from pydantic import (
BaseModel,
GetJsonSchemaHandler,
ValidationError,
field_validator,
)
from pydantic.fields import Field
@ -190,6 +191,39 @@ class UnknownGraphValidationError(ValueError):
pass
class NodeInputError(ValueError):
"""Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an
input fails validation.
Attributes:
node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to
determine which field caused the failure.
"""
def __init__(self, node: BaseInvocation, e: ValidationError):
self.original_error = e
self.node = node
# When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error
# represents the first input that failed.
self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"])
super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}")
def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str:
"""Helper to pretty-print pydantic error locations as dot-separated strings.
Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages
"""
path = ""
for i, x in enumerate(loc):
if isinstance(x, str):
if i > 0:
path += "."
path += x
else:
path += f"[{x}]"
return path
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
@ -821,7 +855,10 @@ class GraphExecutionState(BaseModel):
# Get values from edges
if next_node is not None:
self._prepare_inputs(next_node)
try:
self._prepare_inputs(next_node)
except ValidationError as e:
raise NodeInputError(next_node, e)
# If next is still none, there's no next node, return None
return next_node

View File

@ -1,7 +1,6 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union
from PIL.Image import Image
from torch import Tensor
@ -449,10 +448,10 @@ class ConfigInterface(InvocationContextInterface):
class UtilInterface(InvocationContextInterface):
def __init__(
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
) -> None:
super().__init__(services, data)
self._cancel_event = cancel_event
self._is_canceled = is_canceled
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.
@ -460,7 +459,7 @@ class UtilInterface(InvocationContextInterface):
Returns:
True if the current session has been canceled, False if not.
"""
return self._cancel_event.is_set()
return self._is_canceled()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
"""
@ -535,7 +534,7 @@ class InvocationContext:
def build_invocation_context(
services: InvocationServices,
data: InvocationContextData,
cancel_event: threading.Event,
is_canceled: Callable[[], bool],
) -> InvocationContext:
"""Builds the invocation context for a specific invocation execution.
@ -552,7 +551,7 @@ def build_invocation_context(
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
boards = BoardsInterface(services=services, data=data)

View File

@ -12,6 +12,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -41,6 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_7())
migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.run_migrations()
return db

View File

@ -0,0 +1,35 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration10Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._update_error_cols(cursor)
def _update_error_cols(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;")
cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;")
cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;")
def build_migration_10() -> Migration:
"""
Build the migration from database version 9 to 10.
This migration does the following:
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
migration_10 = Migration(
from_version=9,
to_version=10,
callback=Migration10Callback(),
)
return migration_10

View File

@ -30,8 +30,12 @@ def convert_ldm_vae_to_diffusers(
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae.to(precision)
with torch.no_grad():
vae.load_state_dict(converted_vae_checkpoint)
del converted_vae_checkpoint # Free memory
import gc
gc.collect()
vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
@ -52,7 +56,11 @@ def convert_ckpt_to_diffusers(
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
pipe = pipe.to(precision)
with torch.no_grad():
del kwargs # Free memory
import gc
gc.collect()
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
@ -75,7 +83,11 @@ def convert_controlnet_to_diffusers(
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
pipe = pipe.to(precision)
with torch.no_grad():
del kwargs # Free memory
import gc
gc.collect()
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:

View File

@ -42,10 +42,26 @@ T = TypeVar("T")
@dataclass
class CacheRecord(Generic[T]):
"""Elements of the cache."""
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
"""
key: str
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0

View File

@ -20,7 +20,6 @@ context. Use like this:
import gc
import math
import sys
import time
from contextlib import suppress
from logging import Logger
@ -162,7 +161,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
if key in self._cached_models:
return
self.make_room(size)
cache_record = CacheRecord(key, model, size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@ -257,17 +258,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return
source_device = cache_entry.model.device
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
@ -347,43 +368,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)

View File

@ -1,7 +1,7 @@
"""Textual Inversion wrapper class."""
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
from compel.embeddings_provider import BaseTextualInversionManager
@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel):
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.pad_tokens: dict[int, list[int]] = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
mapping of tokens to token_ids:
```
<ti_dog>: 49408
<ti_dog-!pad-1>: 49409
<ti_dog-!pad-2>: 49410
<ti_dog-!pad-3>: 49411
```
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
"""
# Short circuit if there are no pad tokens to save a little time.
if len(self.pad_tokens) == 0:
return token_ids
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
# this assumption here.
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
# Expand any TI tokens to their corresponding pad tokens.
new_token_ids: list[int] = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
# Do not exceed the max model input size. The -2 here is compensating for
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
max_length = self.tokenizer.model_max_length - 2
if len(new_token_ids) > max_length:
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
# prompts.
new_token_ids = new_token_ids[0:max_length]
return new_token_ids

View File

@ -2,6 +2,7 @@
"accessibility": {
"about": "About",
"createIssue": "Create Issue",
"submitSupportTicket": "Submit Support Ticket",
"invokeProgressBar": "Invoke progress bar",
"menu": "Menu",
"mode": "Mode",
@ -146,7 +147,9 @@
"viewing": "Viewing",
"viewingDesc": "Review images in a large gallery view",
"editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas"
"editingDesc": "Edit on the Control Layers canvas",
"enabled": "Enabled",
"disabled": "Disabled"
},
"controlnet": {
"controlAdapter_one": "Control Adapter",
@ -897,7 +900,10 @@
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time."
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default"
},
"parameters": {
"aspect": "Aspect",
@ -1070,8 +1076,9 @@
},
"toast": {
"addedToBoard": "Added to board",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Base model changed, cleared or disabled {{count}} incompatible submodels",
"baseModelChanged": "Base Model Changed",
"baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels",
"canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded",
@ -1092,10 +1099,17 @@
"metadataLoadFailed": "Failed to load metadata",
"modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled",
"outOfMemoryError": "Out of Memory Error",
"outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.",
"parameters": "Parameters",
"parameterNotSet": "{{parameter}} not set",
"parameterSet": "{{parameter}} set",
"parametersNotSet": "Parameters Not Set",
"parameterSet": "Parameter Recalled",
"parameterSetDesc": "Recalled {{parameter}}",
"parameterNotSet": "Parameter Not Recalled",
"parameterNotSetDesc": "Unable to recall {{parameter}}",
"parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}",
"parametersSet": "Parameters Recalled",
"parametersNotSet": "Parameters Not Recalled",
"errorCopied": "Error Copied",
"problemCopyingCanvas": "Problem Copying Canvas",
"problemCopyingCanvasDesc": "Unable to export base layer",
"problemCopyingImage": "Unable to Copy Image",
@ -1115,11 +1129,13 @@
"sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas",
"serverError": "Server Error",
"sessionRef": "Session: {{sessionId}}",
"setAsCanvasInitialImage": "Set as canvas initial image",
"setCanvasInitialImage": "Set canvas initial image",
"setControlImage": "Set as control image",
"setInitialImage": "Set as initial image",
"setNodeField": "Set as node field",
"somethingWentWrong": "Something Went Wrong",
"uploadFailed": "Upload failed",
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"uploadInitialImage": "Upload Initial Image",
@ -1559,7 +1575,6 @@
"controlLayers": "Control Layers",
"globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative",
"toggleVisibility": "Toggle Layer Visibility",
"deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",

View File

@ -382,7 +382,7 @@
"canvasMerged": "Lienzo consolidado",
"sentToImageToImage": "Enviar hacia Imagen a Imagen",
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
"parametersNotSet": "Parámetros no establecidos",
"parametersNotSet": "Parámetros no recuperados",
"metadataLoadFailed": "Error al cargar metadatos",
"serverError": "Error en el servidor",
"canceled": "Procesando la cancelación",
@ -390,7 +390,8 @@
"uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG",
"parameterSet": "Conjunto de parámetros",
"parameterNotSet": "Parámetro no configurado",
"problemCopyingImage": "No se puede copiar la imagen"
"problemCopyingImage": "No se puede copiar la imagen",
"errorCopied": "Error al copiar"
},
"tooltip": {
"feature": {

View File

@ -524,7 +524,20 @@
"missingNodeTemplate": "Modello di nodo mancante",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante",
"missingFieldTemplate": "Modello di campo mancante",
"imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata"
"imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata",
"layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
"controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
"ipAdapterNoModelSelected": "Nessun adattatore IP selezionato",
"ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile",
"ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata",
"rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP",
"rgNoRegion": "Nessuna regione selezionata"
}
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@ -824,8 +837,8 @@
"unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi",
"addLinearView": "Aggiungi alla vista Lineare",
"unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro",
"collectionFieldType": "{{name}} Raccolta",
"collectionOrScalarFieldType": "{{name}} Raccolta|Scalare",
"collectionFieldType": "{{name}} (Raccolta)",
"collectionOrScalarFieldType": "{{name}} (Singola o Raccolta)",
"nodeVersion": "Versione Nodo",
"inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})",
"unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"",
@ -863,7 +876,13 @@
"edit": "Modifica",
"graph": "Grafico",
"showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati",
"showEdgeLabels": "Mostra le etichette del collegamento"
"showEdgeLabels": "Mostra le etichette del collegamento",
"cannotMixAndMatchCollectionItemTypes": "Impossibile combinare e abbinare i tipi di elementi della raccolta",
"noGraph": "Nessun grafico",
"missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@ -1034,7 +1053,16 @@
"graphFailedToQueue": "Impossibile mettere in coda il grafico",
"batchFieldValues": "Valori Campi Lotto",
"time": "Tempo",
"openQueue": "Apri coda"
"openQueue": "Apri coda",
"iterations_one": "Iterazione",
"iterations_many": "Iterazioni",
"iterations_other": "Iterazioni",
"prompts_one": "Prompt",
"prompts_many": "Prompt",
"prompts_other": "Prompt",
"generations_one": "Generazione",
"generations_many": "Generazioni",
"generations_other": "Generazioni"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@ -1563,7 +1591,6 @@
"brushSize": "Dimensioni del pennello",
"globalMaskOpacity": "Opacità globale della maschera",
"autoNegative": "Auto Negativo",
"toggleVisibility": "Attiva/disattiva la visibilità dei livelli",
"deletePrompt": "Cancella il prompt",
"debugLayers": "Debug dei Livelli",
"rectangle": "Rettangolo",

View File

@ -6,7 +6,7 @@
"settingsLabel": "Instellingen",
"img2img": "Afbeelding naar afbeelding",
"unifiedCanvas": "Centraal canvas",
"nodes": "Werkstroom-editor",
"nodes": "Werkstromen",
"upload": "Upload",
"load": "Laad",
"statusDisconnected": "Niet verbonden",
@ -34,7 +34,60 @@
"controlNet": "ControlNet",
"imageFailedToLoad": "Kan afbeelding niet laden",
"learnMore": "Meer informatie",
"advanced": "Uitgebreid"
"advanced": "Uitgebreid",
"file": "Bestand",
"installed": "Geïnstalleerd",
"notInstalled": "Niet $t(common.installed)",
"simple": "Eenvoudig",
"somethingWentWrong": "Er ging iets mis",
"add": "Voeg toe",
"checkpoint": "Checkpoint",
"details": "Details",
"outputs": "Uitvoeren",
"save": "Bewaar",
"nextPage": "Volgende pagina",
"blue": "Blauw",
"alpha": "Alfa",
"red": "Rood",
"editor": "Editor",
"folder": "Map",
"format": "structuur",
"goTo": "Ga naar",
"template": "Sjabloon",
"input": "Invoer",
"loglevel": "Logboekniveau",
"safetensors": "Safetensors",
"saveAs": "Bewaar als",
"created": "Gemaakt",
"green": "Groen",
"tab": "Tab",
"positivePrompt": "Positieve prompt",
"negativePrompt": "Negatieve prompt",
"selected": "Geselecteerd",
"orderBy": "Sorteer op",
"prevPage": "Vorige pagina",
"beta": "Bèta",
"copyError": "$t(gallery.copy) Fout",
"toResolve": "Op te lossen",
"aboutDesc": "Gebruik je Invoke voor het werk? Kijk dan naar:",
"aboutHeading": "Creatieve macht voor jou",
"copy": "Kopieer",
"data": "Gegevens",
"or": "of",
"updated": "Bijgewerkt",
"outpaint": "outpainten",
"viewing": "Bekijken",
"viewingDesc": "Beoordeel afbeelding in een grote galerijweergave",
"editing": "Bewerken",
"editingDesc": "Bewerk op het canvas Stuurlagen",
"ai": "ai",
"inpaint": "inpainten",
"unknown": "Onbekend",
"delete": "Verwijder",
"direction": "Richting",
"error": "Fout",
"localSystem": "Lokaal systeem",
"unknownError": "Onbekende fout"
},
"gallery": {
"galleryImageSize": "Afbeeldingsgrootte",
@ -310,10 +363,41 @@
"modelSyncFailed": "Synchronisatie modellen mislukt",
"modelDeleteFailed": "Model kon niet verwijderd worden",
"convertingModelBegin": "Model aan het converteren. Even geduld.",
"predictionType": "Soort voorspelling (voor Stable Diffusion 2.x-modellen en incidentele Stable Diffusion 1.x-modellen)",
"predictionType": "Soort voorspelling",
"advanced": "Uitgebreid",
"modelType": "Soort model",
"vaePrecision": "Nauwkeurigheid VAE"
"vaePrecision": "Nauwkeurigheid VAE",
"loraTriggerPhrases": "LoRA-triggerzinnen",
"urlOrLocalPathHelper": "URL's zouden moeten wijzen naar een los bestand. Lokale paden kunnen wijzen naar een los bestand of map voor een individueel Diffusers-model.",
"modelName": "Modelnaam",
"path": "Pad",
"triggerPhrases": "Triggerzinnen",
"typePhraseHere": "Typ zin hier in",
"useDefaultSettings": "Gebruik standaardinstellingen",
"modelImageDeleteFailed": "Fout bij verwijderen modelafbeelding",
"modelImageUpdated": "Modelafbeelding bijgewerkt",
"modelImageUpdateFailed": "Fout bij bijwerken modelafbeelding",
"noMatchingModels": "Geen overeenkomende modellen",
"scanPlaceholder": "Pad naar een lokale map",
"noModelsInstalled": "Geen modellen geïnstalleerd",
"noModelsInstalledDesc1": "Installeer modellen met de",
"noModelSelected": "Geen model geselecteerd",
"starterModels": "Beginnermodellen",
"textualInversions": "Tekstuele omkeringen",
"upcastAttention": "Upcast-aandacht",
"uploadImage": "Upload afbeelding",
"mainModelTriggerPhrases": "Triggerzinnen hoofdmodel",
"urlOrLocalPath": "URL of lokaal pad",
"scanFolderHelper": "De map zal recursief worden ingelezen voor modellen. Dit kan enige tijd in beslag nemen voor erg grote mappen.",
"simpleModelPlaceholder": "URL of pad naar een lokaal pad of Diffusers-map",
"modelSettings": "Modelinstellingen",
"pathToConfig": "Pad naar configuratie",
"prune": "Snoei",
"pruneTooltip": "Snoei voltooide importeringen uit wachtrij",
"repoVariant": "Repovariant",
"scanFolder": "Lees map in",
"scanResults": "Resultaten inlezen",
"source": "Bron"
},
"parameters": {
"images": "Afbeeldingen",
@ -353,13 +437,13 @@
"copyImage": "Kopieer afbeelding",
"denoisingStrength": "Sterkte ontruisen",
"scheduler": "Planner",
"seamlessXAxis": "X-as",
"seamlessYAxis": "Y-as",
"seamlessXAxis": "Naadloze tegels in x-as",
"seamlessYAxis": "Naadloze tegels in y-as",
"clipSkip": "Overslaan CLIP",
"negativePromptPlaceholder": "Negatieve prompt",
"controlNetControlMode": "Aansturingsmodus",
"positivePromptPlaceholder": "Positieve prompt",
"maskBlur": "Vervaag",
"maskBlur": "Vervaging van masker",
"invoke": {
"noNodesInGraph": "Geen knooppunten in graaf",
"noModelSelected": "Geen model ingesteld",
@ -369,11 +453,25 @@
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt",
"noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding",
"noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.",
"incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is ongeldig in combinatie met het hoofdmodel.",
"incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is niet compatibel met het hoofdmodel.",
"systemDisconnected": "Systeem is niet verbonden",
"missingNodeTemplate": "Knooppuntsjabloon ontbreekt",
"missingFieldTemplate": "Veldsjabloon ontbreekt",
"addingImagesTo": "Bezig met toevoegen van afbeeldingen aan"
"addingImagesTo": "Bezig met toevoegen van afbeeldingen aan",
"layer": {
"initialImageNoImageSelected": "geen initiële afbeelding geselecteerd",
"controlAdapterNoModelSelected": "geen controle-adaptermodel geselecteerd",
"controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter",
"controlAdapterNoImageSelected": "geen afbeelding voor controle-adapter geselecteerd",
"controlAdapterImageNotProcessed": "Afbeelding voor controle-adapter niet verwerkt",
"ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter",
"ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd",
"rgNoRegion": "geen gebied geselecteerd",
"rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters",
"t2iAdapterIncompatibleDimensions": "T2I-adapter vereist een afbeelding met afmetingen met een veelvoud van 64",
"ipAdapterNoModelSelected": "geen IP-adapter geselecteerd"
},
"imageNotProcessedForControlAdapter": "De afbeelding van controle-adapter #{{number}} is niet verwerkt"
},
"isAllowedToUpscale": {
"useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model",
@ -383,7 +481,26 @@
"useCpuNoise": "Gebruik CPU-ruis",
"imageActions": "Afbeeldingshandeling",
"iterations": "Iteraties",
"coherenceMode": "Modus"
"coherenceMode": "Modus",
"infillColorValue": "Vulkleur",
"remixImage": "Meng afbeelding opnieuw",
"setToOptimalSize": "Optimaliseer grootte voor het model",
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (is mogelijk te klein)",
"aspect": "Beeldverhouding",
"infillMosaicTileWidth": "Breedte tegel",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (is mogelijk te groot)",
"lockAspectRatio": "Zet beeldverhouding vast",
"infillMosaicTileHeight": "Hoogte tegel",
"globalNegativePromptPlaceholder": "Globale negatieve prompt",
"globalPositivePromptPlaceholder": "Globale positieve prompt",
"useSize": "Gebruik grootte",
"swapDimensions": "Wissel afmetingen om",
"globalSettings": "Globale instellingen",
"coherenceEdgeSize": "Randgrootte",
"coherenceMinDenoise": "Min. ontruising",
"infillMosaicMinColor": "Min. kleur",
"infillMosaicMaxColor": "Max. kleur",
"cfgRescaleMultiplier": "Vermenigvuldiger voor CFG-herschaling"
},
"settings": {
"models": "Modellen",
@ -410,7 +527,12 @@
"intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist",
"intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist",
"clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.",
"intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen"
"intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen",
"clearIntermediatesDisabled": "Wachtrij moet leeg zijn om tussentijdse afbeeldingen te kunnen leegmaken",
"enableInformationalPopovers": "Schakel informatieve hulpballonnen in",
"enableInvisibleWatermark": "Schakel onzichtbaar watermerk in",
"enableNSFWChecker": "Schakel NSFW-controle in",
"reloadingIn": "Opnieuw laden na"
},
"toast": {
"uploadFailed": "Upload mislukt",
@ -425,8 +547,8 @@
"connected": "Verbonden met server",
"canceled": "Verwerking geannuleerd",
"uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn",
"parameterNotSet": "Parameter niet ingesteld",
"parameterSet": "Instellen parameters",
"parameterNotSet": "{{parameter}} niet ingesteld",
"parameterSet": "{{parameter}} ingesteld",
"problemCopyingImage": "Kan Afbeelding Niet Kopiëren",
"baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld",
"baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld",
@ -443,11 +565,11 @@
"maskSavedAssets": "Masker bewaard in Assets",
"problemDownloadingCanvas": "Fout bij downloaden van canvas",
"problemMergingCanvas": "Fout bij samenvoegen canvas",
"setCanvasInitialImage": "Ingesteld als initiële canvasafbeelding",
"setCanvasInitialImage": "Initiële canvasafbeelding ingesteld",
"imageUploaded": "Afbeelding geüpload",
"addedToBoard": "Toegevoegd aan bord",
"workflowLoaded": "Werkstroom geladen",
"modelAddedSimple": "Model toegevoegd",
"modelAddedSimple": "Model toegevoegd aan wachtrij",
"problemImportingMaskDesc": "Kan masker niet exporteren",
"problemCopyingCanvas": "Fout bij kopiëren canvas",
"problemSavingCanvas": "Fout bij bewaren canvas",
@ -459,7 +581,18 @@
"maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets",
"canvasSavedGallery": "Canvas bewaard in galerij",
"imageUploadFailed": "Fout bij uploaden afbeelding",
"problemImportingMask": "Fout bij importeren masker"
"problemImportingMask": "Fout bij importeren masker",
"workflowDeleted": "Werkstroom verwijderd",
"invalidUpload": "Ongeldige upload",
"uploadInitialImage": "Initiële afbeelding uploaden",
"setAsCanvasInitialImage": "Ingesteld als initiële afbeelding voor canvas",
"problemRetrievingWorkflow": "Fout bij ophalen van werkstroom",
"parameters": "Parameters",
"modelImportCanceled": "Importeren model geannuleerd",
"problemDeletingWorkflow": "Fout bij verwijderen van werkstroom",
"prunedQueue": "Wachtrij gesnoeid",
"problemDownloadingImage": "Fout bij downloaden afbeelding",
"resetInitialImage": "Initiële afbeelding hersteld"
},
"tooltip": {
"feature": {
@ -533,7 +666,11 @@
"showOptionsPanel": "Toon zijscherm",
"menu": "Menu",
"showGalleryPanel": "Toon deelscherm Galerij",
"loadMore": "Laad meer"
"loadMore": "Laad meer",
"about": "Over",
"mode": "Modus",
"resetUI": "$t(accessibility.reset) UI",
"createIssue": "Maak probleem aan"
},
"nodes": {
"zoomOutNodes": "Uitzoomen",
@ -547,7 +684,7 @@
"loadWorkflow": "Laad werkstroom",
"downloadWorkflow": "Download JSON van werkstroom",
"scheduler": "Planner",
"missingTemplate": "Ontbrekende sjabloon",
"missingTemplate": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een ontbrekend sjabloon (niet geïnstalleerd?)",
"workflowDescription": "Korte beschrijving",
"versionUnknown": " Versie onbekend",
"noNodeSelected": "Geen knooppunt gekozen",
@ -563,7 +700,7 @@
"integer": "Geheel getal",
"nodeTemplate": "Sjabloon knooppunt",
"nodeOpacity": "Dekking knooppunt",
"unableToLoadWorkflow": "Kan werkstroom niet valideren",
"unableToLoadWorkflow": "Fout bij laden werkstroom",
"snapToGrid": "Lijn uit op raster",
"noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave",
"nodeSearch": "Zoek naar knooppunten",
@ -614,11 +751,56 @@
"unknownField": "Onbekend veld",
"colorCodeEdges": "Kleurgecodeerde randen",
"unknownNode": "Onbekend knooppunt",
"mismatchedVersion": "Heeft niet-overeenkomende versie",
"mismatchedVersion": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een niet-overeenkomende versie (probeer het bij te werken?)",
"addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)",
"loadingNodes": "Bezig met laden van knooppunten...",
"snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing",
"workflowSettings": "Instellingen werkstroomeditor"
"workflowSettings": "Instellingen werkstroomeditor",
"addLinearView": "Voeg toe aan lineaire weergave",
"nodePack": "Knooppuntpakket",
"unknownInput": "Onbekende invoer: {{name}}",
"sourceNodeFieldDoesNotExist": "Ongeldige rand: bron-/uitvoerveld {{node}}.{{field}} bestaat niet",
"collectionFieldType": "Verzameling {{name}}",
"deletedInvalidEdge": "Ongeldige hoek {{source}} -> {{target}} verwijderd",
"graph": "Grafiek",
"targetNodeDoesNotExist": "Ongeldige rand: doel-/invoerknooppunt {{node}} bestaat niet",
"resetToDefaultValue": "Herstel naar standaardwaarden",
"editMode": "Bewerk in Werkstroom-editor",
"showEdgeLabels": "Toon randlabels",
"showEdgeLabelsHelp": "Toon labels aan randen, waarmee de verbonden knooppunten mee worden aangegeven",
"clearWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"unableToParseFieldType": "fout bij bepalen soort veld",
"sourceNodeDoesNotExist": "Ongeldige rand: bron-/uitvoerknooppunt {{node}} bestaat niet",
"unsupportedArrayItemType": "niet-ondersteunde soort van het array-onderdeel \"{{type}}\"",
"targetNodeFieldDoesNotExist": "Ongeldige rand: doel-/invoerveld {{node}}.{{field}} bestaat niet",
"reorderLinearView": "Herorden lineaire weergave",
"newWorkflowDesc": "Een nieuwe werkstroom aanmaken?",
"collectionOrScalarFieldType": "Verzameling|scalair {{name}}",
"newWorkflow": "Nieuwe werkstroom",
"unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom",
"unsupportedAnyOfLength": "te veel union-leden ({{count}})",
"unknownOutput": "Onbekende uitvoer: {{name}}",
"viewMode": "Gebruik in lineaire weergave",
"unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref",
"unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}",
"unknownNodeType": "Onbekend soort knooppunt",
"edit": "Bewerk",
"updateAllNodes": "Werk knooppunten bij",
"allNodesUpdated": "Alle knooppunten bijgewerkt",
"nodeVersion": "Knooppuntversie",
"newWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"clearWorkflow": "Maak werkstroom leeg",
"clearWorkflowDesc": "Deze werkstroom leegmaken en met een nieuwe beginnen?",
"inputFieldTypeParseError": "Fout bij bepalen van het soort invoerveld {{node}}.{{field}} ({{message}})",
"outputFieldTypeParseError": "Fout bij het bepalen van het soort uitvoerveld {{node}}.{{field}} ({{message}})",
"unableToExtractEnumOptions": "fout bij extraheren enumeratie-opties",
"unknownFieldType": "Soort $t(nodes.unknownField): {{type}}",
"unableToGetWorkflowVersion": "Fout bij ophalen schemaversie van werkstroom",
"betaDesc": "Deze uitvoering is in bèta. Totdat deze stabiel is kunnen er wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. We zijn van plan om deze uitvoering op lange termijn te gaan ondersteunen.",
"prototypeDesc": "Deze uitvoering is een prototype. Er kunnen wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. Deze kunnen op een willekeurig moment verwijderd worden.",
"noFieldsViewMode": "Deze werkstroom heeft geen geselecteerde velden om te tonen. Bekijk de volledige werkstroom om de waarden te configureren.",
"unableToUpdateNodes_one": "Fout bij bijwerken van {{count}} knooppunt",
"unableToUpdateNodes_other": "Fout bij bijwerken van {{count}} knooppunten"
},
"controlnet": {
"amult": "a_mult",
@ -691,9 +873,28 @@
"canny": "Canny",
"depthZoeDescription": "Genereer diepteblad via Zoe",
"hedDescription": "Herkenning van holistisch-geneste randen",
"setControlImageDimensions": "Stel afmetingen controle-afbeelding in op B/H",
"setControlImageDimensions": "Kopieer grootte naar B/H (optimaliseer voor model)",
"scribble": "Krabbel",
"maxFaces": "Max. gezichten"
"maxFaces": "Max. gezichten",
"dwOpenpose": "DW Openpose",
"depthAnything": "Depth Anything",
"base": "Basis",
"hands": "Handen",
"selectCLIPVisionModel": "Selecteer een CLIP Vision-model",
"modelSize": "Modelgrootte",
"small": "Klein",
"large": "Groot",
"resizeSimple": "Wijzig grootte (eenvoudig)",
"beginEndStepPercentShort": "Begin-/eind-%",
"depthAnythingDescription": "Genereren dieptekaart d.m.v. de techniek Depth Anything",
"face": "Gezicht",
"body": "Lichaam",
"dwOpenposeDescription": "Schatting menselijke pose d.m.v. DW Openpose",
"ipAdapterMethod": "Methode",
"full": "Volledig",
"style": "Alleen stijl",
"composition": "Alleen samenstelling",
"setControlImageDimensionsForce": "Kopieer grootte naar B/H (negeer model)"
},
"dynamicPrompts": {
"seedBehaviour": {
@ -706,7 +907,10 @@
"maxPrompts": "Max. prompts",
"promptsWithCount_one": "{{count}} prompt",
"promptsWithCount_other": "{{count}} prompts",
"dynamicPrompts": "Dynamische prompts"
"dynamicPrompts": "Dynamische prompts",
"showDynamicPrompts": "Toon dynamische prompts",
"loading": "Genereren van dynamische prompts...",
"promptsPreview": "Voorvertoning prompts"
},
"popovers": {
"noiseUseCPU": {
@ -719,7 +923,7 @@
},
"paramScheduler": {
"paragraphs": [
"De planner bepaalt hoe ruis per iteratie wordt toegevoegd aan een afbeelding of hoe een monster wordt bijgewerkt op basis van de uitvoer van een model."
"De planner gebruikt gedurende het genereringsproces."
],
"heading": "Planner"
},
@ -806,8 +1010,8 @@
},
"clipSkip": {
"paragraphs": [
"Kies hoeveel CLIP-modellagen je wilt overslaan.",
"Bepaalde modellen werken beter met bepaalde Overslaan CLIP-instellingen."
"Aantal over te slaan CLIP-modellagen.",
"Bepaalde modellen zijn beter geschikt met bepaalde Overslaan CLIP-instellingen."
],
"heading": "Overslaan CLIP"
},
@ -991,17 +1195,26 @@
"denoisingStrength": "Sterkte ontruising",
"refinermodel": "Verfijningsmodel",
"posAestheticScore": "Positieve esthetische score",
"concatPromptStyle": "Plak prompt- en stijltekst aan elkaar",
"concatPromptStyle": "Koppelen van prompt en stijl",
"loading": "Bezig met laden...",
"steps": "Stappen",
"posStylePrompt": "Positieve-stijlprompt"
"posStylePrompt": "Positieve-stijlprompt",
"freePromptStyle": "Handmatige stijlprompt",
"refinerSteps": "Aantal stappen verfijner"
},
"models": {
"noMatchingModels": "Geen overeenkomend modellen",
"loading": "bezig met laden",
"noMatchingLoRAs": "Geen overeenkomende LoRA's",
"noModelsAvailable": "Geen modellen beschikbaar",
"selectModel": "Kies een model"
"selectModel": "Kies een model",
"noLoRAsInstalled": "Geen LoRA's geïnstalleerd",
"noRefinerModelsInstalled": "Geen SDXL-verfijningsmodellen geïnstalleerd",
"defaultVAE": "Standaard-VAE",
"lora": "LoRA",
"esrganModel": "ESRGAN-model",
"addLora": "Voeg LoRA toe",
"concepts": "Concepten"
},
"boards": {
"autoAddBoard": "Voeg automatisch bord toe",
@ -1019,7 +1232,13 @@
"downloadBoard": "Download bord",
"changeBoard": "Wijzig bord",
"loading": "Bezig met laden...",
"clearSearch": "Maak zoekopdracht leeg"
"clearSearch": "Maak zoekopdracht leeg",
"deleteBoard": "Verwijder bord",
"deleteBoardAndImages": "Verwijder bord en afbeeldingen",
"deleteBoardOnly": "Verwijder alleen bord",
"deletedBoardsCannotbeRestored": "Verwijderde borden kunnen niet worden hersteld",
"movingImagesToBoard_one": "Verplaatsen van {{count}} afbeelding naar bord:",
"movingImagesToBoard_other": "Verplaatsen van {{count}} afbeeldingen naar bord:"
},
"invocationCache": {
"disable": "Schakel uit",
@ -1036,5 +1255,39 @@
"clear": "Wis",
"maxCacheSize": "Max. grootte cache",
"cacheSize": "Grootte cache"
},
"accordions": {
"generation": {
"title": "Genereren"
},
"image": {
"title": "Afbeelding"
},
"advanced": {
"title": "Geavanceerd",
"options": "$t(accordions.advanced.title) Opties"
},
"control": {
"title": "Besturing"
},
"compositing": {
"title": "Samenstellen",
"coherenceTab": "Coherentiefase",
"infillTab": "Invullen"
}
},
"hrf": {
"upscaleMethod": "Opschaalmethode",
"metadata": {
"strength": "Sterkte oplossing voor hoge resolutie",
"method": "Methode oplossing voor hoge resolutie",
"enabled": "Oplossing voor hoge resolutie ingeschakeld"
},
"hrf": "Oplossing voor hoge resolutie",
"enableHrf": "Schakel oplossing in voor hoge resolutie"
},
"prompt": {
"addPromptTrigger": "Voeg prompttrigger toe",
"compatibleEmbeddings": "Compatibele embeddings"
}
}

View File

@ -1594,7 +1594,6 @@
"deleteAll": "Удалить всё",
"addLayer": "Добавить слой",
"moveToFront": "На передний план",
"toggleVisibility": "Переключить видимость слоя",
"addPositivePrompt": "Добавить $t(common.positivePrompt)",
"addIPAdapter": "Добавить $t(common.ipAdapter)",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",

View File

@ -25,7 +25,6 @@ import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import PreselectedImage from './PreselectedImage';
import Toaster from './Toaster';
const DEFAULT_CONFIG = {};
@ -96,7 +95,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
<DeleteImageModal />
<ChangeBoardModal />
<DynamicPromptsModal />
<Toaster />
<PreselectedImage selectedImage={selectedImage} />
</ErrorBoundary>
);

View File

@ -1,5 +1,8 @@
import { Button, Flex, Heading, Link, Text, useToast } from '@invoke-ai/ui-library';
import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { toast } from 'features/toast/toast';
import newGithubIssueUrl from 'new-github-issue-url';
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiArrowSquareOutBold, PiCopyBold } from 'react-icons/pi';
@ -11,31 +14,39 @@ type Props = {
};
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
const toast = useToast();
const { t } = useTranslation();
const isLocal = useAppSelector((s) => s.config.isLocal);
const handleCopy = useCallback(() => {
const text = JSON.stringify(serializeError(error), null, 2);
navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``);
toast({
title: 'Error Copied',
id: 'ERROR_COPIED',
title: t('toast.errorCopied'),
});
}, [error, toast]);
}, [error, t]);
const url = useMemo(
() =>
newGithubIssueUrl({
const url = useMemo(() => {
if (isLocal) {
return newGithubIssueUrl({
user: 'invoke-ai',
repo: 'InvokeAI',
template: 'BUG_REPORT.yml',
title: `[bug]: ${error.name}: ${error.message}`,
}),
[error.message, error.name]
);
});
} else {
return 'https://support.invoke.ai/support/tickets/new';
}
}, [error.message, error.name, isLocal]);
return (
<Flex layerStyle="body" w="100vw" h="100vh" alignItems="center" justifyContent="center" p={4}>
<Flex layerStyle="first" flexDir="column" borderRadius="base" justifyContent="center" gap={8} p={16}>
<Heading>{t('common.somethingWentWrong')}</Heading>
<Flex alignItems="center" gap="2">
<Image src={InvokeLogoYellow} alt="invoke-logo" w="24px" h="24px" minW="24px" minH="24px" userSelect="none" />
<Heading fontSize="2xl">{t('common.somethingWentWrong')}</Heading>
</Flex>
<Flex
layerStyle="second"
px={8}
@ -57,7 +68,9 @@ const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
{t('common.copyError')}
</Button>
<Link href={url} isExternal>
<Button leftIcon={<PiArrowSquareOutBold />}>{t('accessibility.createIssue')}</Button>
<Button leftIcon={<PiArrowSquareOutBold />}>
{isLocal ? t('accessibility.createIssue') : t('accessibility.submitSupportTicket')}
</Button>
</Link>
</Flex>
</Flex>

View File

@ -1,44 +0,0 @@
import { useToast } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
import type { MakeToastArg } from 'features/system/util/makeToast';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useEffect } from 'react';
/**
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
* @returns null
*/
const Toaster = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector((s) => s.system.toastQueue);
const toast = useToast();
useEffect(() => {
toastQueue.forEach((t) => {
toast(t);
});
toastQueue.length > 0 && dispatch(clearToastQueue());
}, [dispatch, toast, toastQueue]);
return null;
};
/**
* Returns a function that can be used to make a toast.
* @example
* const toaster = useAppToaster();
* toaster('Hello world!');
* toaster({ title: 'Hello world!', status: 'success' });
* @returns A function that can be used to make a toast.
* @see makeToast
* @see MakeToastArg
* @see UseToastOptions
*/
export const useAppToaster = () => {
const dispatch = useAppDispatch();
const toaster = useCallback((arg: MakeToastArg) => dispatch(addToast(makeToast(arg))), [dispatch]);
return toaster;
};
export default memo(Toaster);

View File

@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
import { setEventListeners } from 'services/events/setEventListeners';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import { setEventListeners } from 'services/events/util/setEventListeners';
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
import { io } from 'socket.io-client';

View File

@ -39,7 +39,6 @@ import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMidd
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
@ -99,7 +98,6 @@ addCommitStagingAreaImageListener(startAppListening);
// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addGraphExecutionStateCompleteEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);

View File

@ -8,7 +8,7 @@ import {
resetCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
@ -30,22 +30,20 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
req.reset();
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);
dispatch(
addToast({
title: t('queue.cancelBatchSucceeded'),
status: 'success',
})
);
toast({
id: 'CANCEL_BATCH_SUCCEEDED',
title: t('queue.cancelBatchSucceeded'),
status: 'success',
});
}
dispatch(canvasBatchIdsReset());
} catch {
log.error('Failed to cancel canvas batches');
dispatch(
addToast({
title: t('queue.cancelBatchFailed'),
status: 'error',
})
);
toast({
id: 'CANCEL_BATCH_FAILED',
title: t('queue.cancelBatchFailed'),
status: 'error',
});
}
},
});

View File

@ -1,8 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { toast } from 'common/util/toast';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { queueApi } from 'services/api/endpoints/queue';
@ -16,18 +16,15 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
const arg = action.meta.arg.originalArgs;
logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
if (!toast.isActive('batch-queued')) {
toast({
id: 'batch-queued',
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
count: response.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'),
}),
duration: 1000,
status: 'success',
});
}
toast({
id: 'QUEUE_BATCH_SUCCEEDED',
title: t('queue.batchQueued'),
status: 'success',
description: t('queue.batchQueuedDesc', {
count: response.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'),
}),
});
},
});
@ -40,9 +37,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
if (!response) {
toast({
id: 'QUEUE_BATCH_FAILED',
title: t('queue.batchFailedToQueue'),
status: 'error',
description: 'Unknown Error',
description: t('common.unknownError'),
});
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
return;
@ -52,7 +50,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
if (result.success) {
result.data.data.detail.map((e) => {
toast({
id: 'batch-failed-to-queue',
id: 'QUEUE_BATCH_FAILED',
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description: truncate(
@ -64,9 +62,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
});
} else if (response.status !== 403) {
toast({
id: 'QUEUE_BATCH_FAILED',
title: t('queue.batchFailedToQueue'),
description: t('common.unknownError'),
status: 'error',
description: t('common.unknownError'),
});
}
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));

View File

@ -1,8 +1,7 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'common/util/toast';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
@ -28,7 +27,6 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// Show the response message if it exists, otherwise show the default message
description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'),
duration: null,
isClosable: true,
});
},
});
@ -40,9 +38,9 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// There isn't any toast to update if we get this event.
toast({
id: 'BULK_DOWNLOAD_REQUEST_FAILED',
title: t('gallery.bulkDownloadRequestFailed'),
status: 'success',
isClosable: true,
status: 'error',
});
},
});
@ -65,7 +63,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`;
const toastOptions: UseToastOptions = {
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success',
@ -77,14 +75,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
/>
),
duration: null,
isClosable: true,
};
if (toast.isActive(bulk_download_item_name)) {
toast.update(bulk_download_item_name, toastOptions);
} else {
toast(toastOptions);
}
});
},
});
@ -95,20 +86,13 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
const { bulk_download_item_name } = action.payload.data;
const toastOptions: UseToastOptions = {
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
status: 'error',
description: action.payload.data.error,
duration: null,
isClosable: true,
};
if (toast.isActive(bulk_download_item_name)) {
toast.update(bulk_download_item_name, toastOptions);
} else {
toast(toastOptions);
}
});
},
});
};

View File

@ -2,14 +2,14 @@ import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasCopiedToClipboard,
effect: async (action, { dispatch, getState }) => {
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const state = getState();
@ -19,22 +19,20 @@ export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartLi
copyBlobToClipboard(blob);
} catch (err) {
moduleLog.error(String(err));
dispatch(
addToast({
title: t('toast.problemCopyingCanvas'),
description: t('toast.problemCopyingCanvasDesc'),
status: 'error',
})
);
toast({
id: 'CANVAS_COPY_FAILED',
title: t('toast.problemCopyingCanvas'),
description: t('toast.problemCopyingCanvasDesc'),
status: 'error',
});
return;
}
dispatch(
addToast({
title: t('toast.canvasCopiedClipboard'),
status: 'success',
})
);
toast({
id: 'CANVAS_COPY_SUCCEEDED',
title: t('toast.canvasCopiedClipboard'),
status: 'success',
});
},
});
};

View File

@ -3,13 +3,13 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasDownloadedAsImage,
effect: async (action, { dispatch, getState }) => {
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
const state = getState();
@ -18,18 +18,17 @@ export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartLi
blob = await getBaseLayerBlob(state);
} catch (err) {
moduleLog.error(String(err));
dispatch(
addToast({
title: t('toast.problemDownloadingCanvas'),
description: t('toast.problemDownloadingCanvasDesc'),
status: 'error',
})
);
toast({
id: 'CANVAS_DOWNLOAD_FAILED',
title: t('toast.problemDownloadingCanvas'),
description: t('toast.problemDownloadingCanvasDesc'),
status: 'error',
});
return;
}
downloadBlob(blob, 'canvas.png');
dispatch(addToast({ title: t('toast.canvasDownloaded'), status: 'success' }));
toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
},
});
};

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -20,13 +20,12 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
blob = await getBaseLayerBlob(state, true);
} catch (err) {
log.error(String(err));
dispatch(
addToast({
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
})
);
toast({
id: 'PROBLEM_SAVING_CANVAS',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
@ -43,7 +42,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
crop_visible: false,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: t('toast.canvasSentControlnetAssets') },
title: t('toast.canvasSentControlnetAssets'),
},
})
).unwrap();

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -29,13 +29,12 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL
if (!maskBlob) {
log.error('Problem getting mask layer blob');
dispatch(
addToast({
title: t('toast.problemSavingMask'),
description: t('toast.problemSavingMaskDesc'),
status: 'error',
})
);
toast({
id: 'PROBLEM_SAVING_MASK',
title: t('toast.problemSavingMask'),
description: t('toast.problemSavingMaskDesc'),
status: 'error',
});
return;
}
@ -52,7 +51,7 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: t('toast.maskSavedAssets') },
title: t('toast.maskSavedAssets'),
},
})
);

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -30,13 +30,12 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
if (!maskBlob) {
log.error('Problem getting mask layer blob');
dispatch(
addToast({
title: t('toast.problemImportingMask'),
description: t('toast.problemImportingMaskDesc'),
status: 'error',
})
);
toast({
id: 'PROBLEM_IMPORTING_MASK',
title: t('toast.problemImportingMask'),
description: t('toast.problemImportingMaskDesc'),
status: 'error',
});
return;
}
@ -53,7 +52,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
crop_visible: false,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: t('toast.maskSentControlnetAssets') },
title: t('toast.maskSentControlnetAssets'),
},
})
).unwrap();

View File

@ -4,7 +4,7 @@ import { canvasMerged } from 'features/canvas/store/actions';
import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -17,13 +17,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
})
);
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
@ -31,13 +30,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer');
dispatch(
addToast({
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
})
);
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
@ -54,7 +52,7 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
is_intermediate: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: t('toast.canvasMerged') },
title: t('toast.canvasMerged'),
},
})
).unwrap();

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -19,13 +19,12 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
blob = await getBaseLayerBlob(state);
} catch (err) {
log.error(String(err));
dispatch(
addToast({
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
})
);
toast({
id: 'CANVAS_SAVE_FAILED',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
@ -42,7 +41,7 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: t('toast.canvasSavedGallery') },
title: t('toast.canvasSavedGallery'),
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),

View File

@ -14,7 +14,7 @@ import {
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
@ -174,12 +174,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
}
}
dispatch(
addToast({
title: t('queue.graphFailedToQueue'),
status: 'error',
})
);
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
} finally {
req.reset();

View File

@ -10,7 +10,7 @@ import {
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
@ -108,12 +108,11 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
}
}
dispatch(
addToast({
title: t('queue.graphFailedToQueue'),
status: 'error',
})
);
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
},
});

View File

@ -1,4 +1,3 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@ -14,7 +13,7 @@ import {
} from 'features/controlLayers/store/controlLayersSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
@ -42,16 +41,17 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return;
}
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
status: 'success',
};
} as const;
// default action - just upload and alert user
if (postUploadAction?.type === 'TOAST') {
const { toastOptions } = postUploadAction;
if (!autoAddBoardId || autoAddBoardId === 'none') {
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
const title = postUploadAction.title || DEFAULT_UPLOADED_TOAST.title;
toast({ ...DEFAULT_UPLOADED_TOAST, title });
} else {
// Add this image to the board
dispatch(
@ -70,24 +70,20 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
? `${t('toast.addedToBoard')} ${board.board_name}`
: `${t('toast.addedToBoard')} ${autoAddBoardId}`;
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description,
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description,
});
}
return;
}
if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') {
dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state)));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setAsCanvasInitialImage'),
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setAsCanvasInitialImage'),
});
return;
}
@ -105,68 +101,56 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
controlImage: imageDTO.image_name,
})
);
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
return;
}
if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(caLayerImageChanged({ layerId, imageDTO }));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(ipaLayerImageChanged({ layerId, imageDTO }));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') {
const { layerId, ipAdapterId } = postUploadAction;
dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO }));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(iiLayerImageChanged({ layerId, imageDTO }));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: `${t('toast.setNodeField')} ${fieldName}`,
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: `${t('toast.setNodeField')} ${fieldName}`,
});
return;
}
},
@ -174,7 +158,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchRejected,
effect: (action, { dispatch }) => {
effect: (action) => {
const log = logger('images');
const sanitizedData = {
arg: {
@ -183,13 +167,11 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
},
};
log.error({ ...sanitizedData }, 'Image upload failed');
dispatch(
addToast({
title: t('toast.imageUploadFailed'),
description: action.error.message,
status: 'error',
})
);
toast({
title: t('toast.imageUploadFailed'),
description: action.error.message,
status: 'error',
});
},
});
};

View File

@ -8,8 +8,7 @@ import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
@ -60,16 +59,14 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
});
if (modelsCleared > 0) {
dispatch(
addToast(
makeToast({
title: t('toast.baseModelChangedCleared', {
count: modelsCleared,
}),
status: 'warning',
})
)
);
toast({
id: 'BASE_MODEL_CHANGED',
title: t('toast.baseModelChanged'),
description: t('toast.baseModelChangedCleared', {
count: modelsCleared,
}),
status: 'warning',
});
}
}

View File

@ -19,8 +19,7 @@ import {
isParameterWidth,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@ -109,7 +108,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
toast({ id: 'PARAMETER_SET', title: t('toast.parameterSet', { parameter: 'Default settings' }) });
}
},
});

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
@ -11,7 +12,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
startAppListening({
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(action.payload, `Generator progress`);
log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketGraphExecutionStateComplete } from 'services/events/actions';
const log = logger('socketio');
export const addGraphExecutionStateCompleteEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketGraphExecutionStateComplete,
effect: (action) => {
log.debug(action.payload, 'Session complete');
},
});
};

View File

@ -29,11 +29,11 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => {
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation_type})`);
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`);
const { result, invocation_source_id } = data;
// This complete event has an associated image output
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation_type)) {
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image;
const { canvas, gallery } = getState();

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationError } from 'services/events/actions';
@ -11,14 +12,18 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
startAppListening({
actionCreator: socketInvocationError,
effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.invocation_type})`);
const { invocation_source_id } = action.payload.data;
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data;
log.error(parseify(action.payload), `Invocation error (${invocation.type})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error;
nes.progress = null;
nes.progressImage = null;
nes.error = {
error_type,
error_message,
error_traceback,
};
upsertExecutionState(nes.nodeId, nes);
}
},

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions';
@ -11,7 +12,7 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.invocation_type})`);
log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`);
const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {

View File

@ -3,6 +3,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
@ -12,10 +14,21 @@ const log = logger('socketio');
export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => {
effect: async (action, { dispatch, getState }) => {
// we've got new status for the queue item, batch and queue
const { item_id, status, started_at, updated_at, error, completed_at, batch_status, queue_status } =
action.payload.data;
const {
item_id,
session_id,
status,
started_at,
updated_at,
completed_at,
batch_status,
queue_status,
error_type,
error_message,
error_traceback,
} = action.payload.data;
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
@ -28,8 +41,10 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
status,
started_at,
updated_at: updated_at ?? undefined,
error,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
},
});
})
@ -61,7 +76,7 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
])
);
if (['in_progress'].includes(action.payload.data.status)) {
if (status === 'in_progress') {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
@ -74,6 +89,25 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
status: 'error',
duration: null,
updateDescription: isLocal,
description: (
<ErrorToastDescription
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
});
}
},
});

View File

@ -1,6 +1,6 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -29,15 +29,14 @@ export const addStagingAreaImageSavedListener = (startAppListening: AppStartList
})
);
}
dispatch(addToast({ title: t('toast.imageSaved'), status: 'success' }));
toast({ id: 'IMAGE_SAVED', title: t('toast.imageSaved'), status: 'success' });
} catch (error) {
dispatch(
addToast({
title: t('toast.imageSavingFailed'),
description: (error as Error)?.message,
status: 'error',
})
);
toast({
id: 'IMAGE_SAVE_FAILED',
title: t('toast.imageSavingFailed'),
description: (error as Error)?.message,
status: 'error',
});
}
},
});

View File

@ -5,8 +5,7 @@ import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartListening) => {
@ -50,24 +49,18 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
count: unableToUpdateCount,
})
);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToUpdateNodes', {
count: unableToUpdateCount,
}),
})
)
);
toast({
id: 'UNABLE_TO_UPDATE_NODES',
title: t('nodes.unableToUpdateNodes', {
count: unableToUpdateCount,
}),
});
} else {
dispatch(
addToast(
makeToast({
title: t('nodes.allNodesUpdated'),
status: 'success',
})
)
);
toast({
id: 'ALL_NODES_UPDATED',
title: t('nodes.allNodesUpdated'),
status: 'success',
});
}
},
});

View File

@ -4,7 +4,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph';
import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale';
import { addToast } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
@ -29,12 +29,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
{ imageDTO },
t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge') // should never coalesce
);
dispatch(
addToast({
title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce
status: 'error',
})
);
toast({
id: 'NOT_ALLOWED_TO_UPSCALE',
title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce
status: 'error',
});
return;
}
@ -65,12 +64,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
if (error instanceof Object && 'status' in error && error.status === 403) {
return;
} else {
dispatch(
addToast({
title: t('queue.graphFailedToQueue'),
status: 'error',
})
);
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
}
},

View File

@ -8,23 +8,23 @@ import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow);
return validateWorkflow(parsed, templates);
return await validateWorkflow(parsed, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
} else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return validateWorkflow(workflow, templates);
return await validateWorkflow(workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
} else {
throw new Error('No workflow or graph provided');
}
@ -33,13 +33,13 @@ const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => {
effect: async (action, { dispatch }) => {
const log = logger('nodes');
const { data, asCopy } = action.payload;
const nodeTemplates = $templates.get();
try {
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
const { workflow, warnings } = await getWorkflow(data, nodeTemplates);
if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
@ -48,23 +48,18 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
dispatch(workflowLoaded(workflow));
if (!warnings.length) {
dispatch(
addToast(
makeToast({
title: t('toast.workflowLoaded'),
status: 'success',
})
)
);
toast({
id: 'WORKFLOW_LOADED',
title: t('toast.workflowLoaded'),
status: 'success',
});
} else {
dispatch(
addToast(
makeToast({
title: t('toast.loadedWithWarnings'),
status: 'warning',
})
)
);
toast({
id: 'WORKFLOW_LOADED',
title: t('toast.loadedWithWarnings'),
status: 'warning',
});
warnings.forEach(({ message, ...rest }) => {
log.warn(rest, message);
});
@ -77,54 +72,42 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions
log.error({ error: parseify(e) }, e.message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: e.message,
})
)
);
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: e.message,
});
} else if (e instanceof WorkflowMigrationError) {
// There was a problem migrating the workflow to the latest version
log.error({ error: parseify(e) }, e.message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: e.message,
})
)
);
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: e.message,
});
} else if (e instanceof z.ZodError) {
// There was a problem validating the workflow itself
const { message } = fromZodError(e, {
prefix: t('nodes.workflowValidation'),
});
log.error({ error: parseify(e) }, message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: message,
})
)
);
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: message,
});
} else {
// Some other error occurred
log.error({ error: parseify(e) }, t('nodes.unknownErrorValidatingWorkflow'));
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: t('nodes.unknownErrorValidatingWorkflow'),
})
)
);
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: t('nodes.unknownErrorValidatingWorkflow'),
});
}
}
},

View File

@ -74,6 +74,7 @@ export type AppConfig = {
maxUpscalePixels?: number;
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];

View File

@ -1,11 +1,10 @@
import { useAppToaster } from 'app/components/Toaster';
import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const useCopyImageToClipboard = () => {
const toaster = useAppToaster();
const { t } = useTranslation();
const imageUrlToBlob = useImageUrlToBlob();
@ -16,12 +15,11 @@ export const useCopyImageToClipboard = () => {
const copyImageToClipboard = useCallback(
async (image_url: string) => {
if (!isClipboardAPIAvailable) {
toaster({
toast({
id: 'PROBLEM_COPYING_IMAGE',
title: t('toast.problemCopyingImage'),
description: "Your browser doesn't support the Clipboard API.",
status: 'error',
duration: 2500,
isClosable: true,
});
}
try {
@ -33,23 +31,21 @@ export const useCopyImageToClipboard = () => {
copyBlobToClipboard(blob);
toaster({
toast({
id: 'IMAGE_COPIED',
title: t('toast.imageCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
} catch (err) {
toaster({
toast({
id: 'PROBLEM_COPYING_IMAGE',
title: t('toast.problemCopyingImage'),
description: String(err),
status: 'error',
duration: 2500,
isClosable: true,
});
}
},
[imageUrlToBlob, isClipboardAPIAvailable, t, toaster]
[imageUrlToBlob, isClipboardAPIAvailable, t]
);
return { isClipboardAPIAvailable, copyImageToClipboard };

View File

@ -1,13 +1,12 @@
import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { $authToken } from 'app/store/nanostores/authToken';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDownloaded } from 'features/gallery/store/actions';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const useDownloadImage = () => {
const toaster = useAppToaster();
const { t } = useTranslation();
const dispatch = useAppDispatch();
const authToken = useStore($authToken);
@ -37,16 +36,15 @@ export const useDownloadImage = () => {
window.URL.revokeObjectURL(url);
dispatch(imageDownloaded());
} catch (err) {
toaster({
toast({
id: 'PROBLEM_DOWNLOADING_IMAGE',
title: t('toast.problemDownloadingImage'),
description: String(err),
status: 'error',
duration: 2500,
isClosable: true,
});
}
},
[t, toaster, dispatch, authToken]
[t, dispatch, authToken]
);
return { downloadImage };

View File

@ -1,6 +1,6 @@
import { useAppToaster } from 'app/components/Toaster';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { toast } from 'features/toast/toast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import type { Accept, FileRejection } from 'react-dropzone';
@ -26,7 +26,6 @@ const selectPostUploadAction = createMemoizedSelector(activeTabNameSelector, (ac
export const useFullscreenDropzone = () => {
const { t } = useTranslation();
const toaster = useAppToaster();
const postUploadAction = useAppSelector(selectPostUploadAction);
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
@ -37,13 +36,14 @@ export const useFullscreenDropzone = () => {
(rejection: FileRejection) => {
setIsHandlingUpload(true);
toaster({
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description: rejection.errors.map((error) => error.message).join('\n'),
status: 'error',
});
},
[t, toaster]
[t]
);
const fileAcceptedCallback = useCallback(
@ -62,7 +62,8 @@ export const useFullscreenDropzone = () => {
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
if (fileRejections.length > 1) {
toaster({
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description: t('toast.uploadFailedInvalidUploadDesc'),
status: 'error',
@ -78,7 +79,7 @@ export const useFullscreenDropzone = () => {
fileAcceptedCallback(file);
});
},
[t, toaster, fileAcceptedCallback, fileRejectionCallback]
[t, fileAcceptedCallback, fileRejectionCallback]
);
const onDragOver = useCallback(() => {

View File

@ -1,6 +0,0 @@
import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
export const { toast } = createStandaloneToast({
theme: theme,
defaultOptions: TOAST_OPTIONS.defaultOptions,
});

View File

@ -4,7 +4,7 @@ import { CALayerControlAdapterWrapper } from 'features/controlLayers/components/
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import { layerSelected, selectCALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react';
@ -26,7 +26,7 @@ export const CALayer = memo(({ layerId }: Props) => {
return (
<LayerWrapper onClick={onClick} borderColor={isSelected ? 'base.400' : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} />
<LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="control_adapter_layer" />
<Spacer />
<CALayerOpacity layerId={layerId} />

View File

@ -5,7 +5,7 @@ import { InitialImagePreview } from 'features/controlLayers/components/IILayer/I
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import {
iiLayerDenoisingStrengthChanged,
@ -66,7 +66,7 @@ export const IILayer = memo(({ layerId }: Props) => {
return (
<LayerWrapper onClick={onClick} borderColor={layer.isSelected ? 'base.400' : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} />
<LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="initial_image_layer" />
<Spacer />
<IILayerOpacity layerId={layerId} />

View File

@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IPALayerIPAdapterWrapper } from 'features/controlLayers/components/IPALayer/IPALayerIPAdapterWrapper';
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import { layerSelected, selectIPALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react';
@ -22,7 +22,7 @@ export const IPALayer = memo(({ layerId }: Props) => {
return (
<LayerWrapper onClick={onClick} borderColor={isSelected ? 'base.400' : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} />
<LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="ip_adapter_layer" />
<Spacer />
<LayerDeleteButton layerId={layerId} />

View File

@ -1,8 +1,8 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPropagation } from 'common/util/stopPropagation';
import { useLayerIsVisible } from 'features/controlLayers/hooks/layerStateHooks';
import { layerVisibilityToggled } from 'features/controlLayers/store/controlLayersSlice';
import { useLayerIsEnabled } from 'features/controlLayers/hooks/layerStateHooks';
import { layerIsEnabledToggled } from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold } from 'react-icons/pi';
@ -11,21 +11,21 @@ type Props = {
layerId: string;
};
export const LayerVisibilityToggle = memo(({ layerId }: Props) => {
export const LayerIsEnabledToggle = memo(({ layerId }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isVisible = useLayerIsVisible(layerId);
const isEnabled = useLayerIsEnabled(layerId);
const onClick = useCallback(() => {
dispatch(layerVisibilityToggled(layerId));
dispatch(layerIsEnabledToggled(layerId));
}, [dispatch, layerId]);
return (
<IconButton
size="sm"
aria-label={t('controlLayers.toggleVisibility')}
tooltip={t('controlLayers.toggleVisibility')}
aria-label={t(isEnabled ? 'common.enabled' : 'common.disabled')}
tooltip={t(isEnabled ? 'common.enabled' : 'common.disabled')}
variant="outline"
icon={isVisible ? <PiCheckBold /> : undefined}
icon={isEnabled ? <PiCheckBold /> : undefined}
onClick={onClick}
colorScheme="base"
onDoubleClick={stopPropagation} // double click expands the layer
@ -33,4 +33,4 @@ export const LayerVisibilityToggle = memo(({ layerId }: Props) => {
);
});
LayerVisibilityToggle.displayName = 'LayerVisibilityToggle';
LayerIsEnabledToggle.displayName = 'LayerVisibilityToggle';

View File

@ -6,7 +6,7 @@ import { AddPromptButtons } from 'features/controlLayers/components/AddPromptBut
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import {
isRegionalGuidanceLayer,
@ -55,7 +55,7 @@ export const RGLayer = memo(({ layerId }: Props) => {
return (
<LayerWrapper onClick={onClick} borderColor={isSelected ? color : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} />
<LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="regional_guidance_layer" />
<Spacer />
{autoNegative === 'invert' && (

View File

@ -45,7 +45,6 @@ export const RGLayerNegativePrompt = memo(({ layerId }: Props) => {
variant="darkFilled"
paddingRight={30}
fontSize="sm"
spellCheck={false}
/>
<PromptOverlayButtonWrapper>
<RGLayerPromptDeleteButton layerId={layerId} polarity="negative" />

View File

@ -45,7 +45,6 @@ export const RGLayerPositivePrompt = memo(({ layerId }: Props) => {
variant="darkFilled"
paddingRight={30}
minH={28}
spellCheck={false}
/>
<PromptOverlayButtonWrapper>
<RGLayerPromptDeleteButton layerId={layerId} polarity="positive" />

View File

@ -39,7 +39,7 @@ export const useLayerNegativePrompt = (layerId: string) => {
return prompt;
};
export const useLayerIsVisible = (layerId: string) => {
export const useLayerIsEnabled = (layerId: string) => {
const selectLayer = useMemo(
() =>
createSelector(selectControlLayersSlice, (controlLayers) => {

View File

@ -139,7 +139,7 @@ export const controlLayersSlice = createSlice({
layerSelected: (state, action: PayloadAction<string>) => {
exclusivelySelectLayer(state, action.payload);
},
layerVisibilityToggled: (state, action: PayloadAction<string>) => {
layerIsEnabledToggled: (state, action: PayloadAction<string>) => {
const layer = state.layers.find((l) => l.id === action.payload);
if (layer) {
layer.isEnabled = !layer.isEnabled;
@ -616,12 +616,24 @@ export const controlLayersSlice = createSlice({
iiLayerAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload;
// Retain opacity and denoising strength of existing initial image layer if exists
let opacity = 1;
let denoisingStrength = 0.75;
const iiLayer = state.layers.find((l) => l.id === layerId);
if (iiLayer) {
assert(isInitialImageLayer(iiLayer));
opacity = iiLayer.opacity;
denoisingStrength = iiLayer.denoisingStrength;
}
// Highlander! There can be only one!
state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true));
const layer: InitialImageLayer = {
id: layerId,
type: 'initial_image_layer',
opacity: 1,
opacity,
x: 0,
y: 0,
bbox: null,
@ -629,7 +641,7 @@ export const controlLayersSlice = createSlice({
isEnabled: true,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
isSelected: true,
denoisingStrength: 0.75,
denoisingStrength,
};
state.layers.push(layer);
exclusivelySelectLayer(state, layer.id);
@ -779,7 +791,7 @@ class LayerColors {
export const {
// Any Layer Type
layerSelected,
layerVisibilityToggled,
layerIsEnabledToggled,
layerTranslated,
layerBboxChanged,
layerReset,

View File

@ -1,6 +1,5 @@
import { Flex, MenuDivider, MenuItem, Spinner } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
@ -14,6 +13,7 @@ import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/ac
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { size } from 'lodash-es';
@ -46,7 +46,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const optimalDimension = useAppSelector(selectOptimalDimension);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isCanvasEnabled = useFeatureStatus('canvas');
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();
@ -86,13 +85,12 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
});
dispatch(setInitialCanvasImage(imageDTO, optimalDimension));
toaster({
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
}, [dispatch, imageDTO, t, toaster, optimalDimension]);
}, [dispatch, imageDTO, t, optimalDimension]);
const handleChangeBoard = useCallback(() => {
dispatch(imagesToChangeSelected([imageDTO]));

View File

@ -1,6 +1,5 @@
import { IconButton } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
@ -14,7 +13,6 @@ export const ToggleMetadataViewerButton = memo(() => {
const dispatch = useAppDispatch();
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const toaster = useAppToaster();
const { t } = useTranslation();
const { currentData: imageDTO } = useGetImageDTOQuery(lastSelectedImage?.image_name ?? skipToken);
@ -24,7 +22,7 @@ export const ToggleMetadataViewerButton = memo(() => {
[dispatch, shouldShowImageDetails]
);
useHotkeys('i', toggleMetadataViewer, { enabled: Boolean(imageDTO) }, [imageDTO, shouldShowImageDetails, toaster]);
useHotkeys('i', toggleMetadataViewer, { enabled: Boolean(imageDTO) }, [imageDTO, shouldShowImageDetails]);
return (
<IconButton

View File

@ -53,7 +53,7 @@ export const useImageActions = (image_name?: string) => {
const recallSeed = useCallback(() => {
handlers.seed.parse(metadata).then((seed) => {
handlers.seed.recall && handlers.seed.recall(seed);
handlers.seed.recall && handlers.seed.recall(seed, true);
});
}, [metadata]);

View File

@ -1,5 +1,4 @@
import { objectKeys } from 'common/util/objectKeys';
import { toast } from 'common/util/toast';
import type { Layer } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
@ -15,6 +14,7 @@ import type {
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { assert } from 'tsafe';
@ -89,23 +89,23 @@ const renderLayersValue: MetadataRenderValueFunc<Layer[]> = async (layers) => {
return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`;
};
const parameterSetToast = (parameter: string, description?: string) => {
const parameterSetToast = (parameter: string) => {
toast({
title: t('toast.parameterSet', { parameter }),
description,
id: 'PARAMETER_SET',
title: t('toast.parameterSet'),
description: t('toast.parameterSetDesc', { parameter }),
status: 'info',
duration: 2500,
isClosable: true,
});
};
const parameterNotSetToast = (parameter: string, description?: string) => {
const parameterNotSetToast = (parameter: string, message?: string) => {
toast({
title: t('toast.parameterNotSet', { parameter }),
description,
id: 'PARAMETER_NOT_SET',
title: t('toast.parameterNotSet'),
description: message
? t('toast.parameterNotSetDescWithMessage', { parameter, message })
: t('toast.parameterNotSetDesc', { parameter }),
status: 'warning',
duration: 2500,
isClosable: true,
});
};
@ -458,7 +458,18 @@ export const parseAndRecallAllMetadata = async (
});
})
);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('toast.parameters'));
toast({
id: 'PARAMETER_SET',
title: t('toast.parametersSet'),
status: 'info',
});
} else {
toast({
id: 'PARAMETER_SET',
title: t('toast.parametersNotSet'),
status: 'warning',
});
}
};

View File

@ -0,0 +1,48 @@
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type InstallModelArg = {
source: string;
inplace?: boolean;
onSuccess?: () => void;
onError?: (error: unknown) => void;
};
export const useInstallModel = () => {
const { t } = useTranslation();
const [_installModel, request] = useInstallModelMutation();
const installModel = useCallback(
({ source, inplace, onSuccess, onError }: InstallModelArg) => {
_installModel({ source, inplace })
.unwrap()
.then((_) => {
if (onSuccess) {
onSuccess();
}
toast({
id: 'MODEL_INSTALL_QUEUED',
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (onError) {
onError(error);
}
if (error) {
toast({
id: 'MODEL_INSTALL_QUEUE_FAILED',
title: `${error.data.detail} `,
status: 'error',
});
}
});
},
[_installModel, t]
);
return [installModel, request] as const;
};

View File

@ -17,7 +17,11 @@ export const useStarterModelsToast = () => {
useEffect(() => {
if (toast.isActive(TOAST_ID)) {
return;
if (mainModels.length === 0) {
return;
} else {
toast.close(TOAST_ID);
}
}
if (data && mainModels.length === 0 && !didToast && isEnabled) {
toast({

View File

@ -1,11 +1,9 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HuggingFaceResults } from './HuggingFaceResults';
@ -14,50 +12,19 @@ export const HuggingFaceForm = () => {
const [displayResults, setDisplayResults] = useState(false);
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModelMutation();
const handleInstallModel = useCallback(
(source: string) => {
installModel({ source })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t]
);
const [installModel] = useInstallModel();
const getModels = useCallback(async () => {
_getHuggingFaceModels(huggingFaceRepo)
.unwrap()
.then((response) => {
if (response.is_diffusers) {
handleInstallModel(huggingFaceRepo);
installModel({ source: huggingFaceRepo });
setDisplayResults(false);
} else if (response.urls?.length === 1 && response.urls[0]) {
handleInstallModel(response.urls[0]);
installModel({ source: response.urls[0] });
setDisplayResults(false);
} else {
setDisplayResults(true);
@ -66,7 +33,7 @@ export const HuggingFaceForm = () => {
.catch((error) => {
setErrorMessage(error.data.detail || '');
});
}, [_getHuggingFaceModels, handleInstallModel, huggingFaceRepo]);
}, [_getHuggingFaceModels, installModel, huggingFaceRepo]);
const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setHuggingFaceRepo(e.target.value);

View File

@ -1,47 +1,20 @@
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
result: string;
};
export const HuggingFaceResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const [installModel] = useInstallModel();
const handleInstall = useCallback(() => {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [installModel, result, dispatch, t]);
const onClick = useCallback(() => {
installModel({ source: result });
}, [installModel, result]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
@ -51,7 +24,7 @@ export const HuggingFaceResultItem = ({ result }: Props) => {
{result}
</Text>
</Flex>
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
</Flex>
);
};

View File

@ -8,15 +8,12 @@ import {
InputGroup,
InputRightElement,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import { HuggingFaceResultItem } from './HuggingFaceResultItem';
@ -27,9 +24,8 @@ type HuggingFaceResultsProps = {
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const [installModel] = useInstallModel();
const filteredResults = useMemo(() => {
return results.filter((result) => {
@ -46,34 +42,11 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
setSearchTerm('');
}, []);
const handleAddAll = useCallback(() => {
const onClickAddAll = useCallback(() => {
for (const result of filteredResults) {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
installModel({ source: result });
}
}, [filteredResults, installModel, dispatch, t]);
}, [filteredResults, installModel]);
return (
<>
@ -82,7 +55,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
<Flex justifyContent="space-between" alignItems="center">
<Heading size="sm">{t('modelManager.availableModels')}</Heading>
<Flex alignItems="center" gap={3}>
<Button size="sm" onClick={handleAddAll} isDisabled={results.length === 0} flexShrink={0}>
<Button size="sm" onClick={onClickAddAll} isDisabled={results.length === 0} flexShrink={0}>
{t('modelManager.installAll')}
</Button>
<InputGroup w={64} size="xs">

View File

@ -1,12 +1,9 @@
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { t } from 'i18next';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type SimpleImportModelConfig = {
location: string;
@ -14,9 +11,7 @@ type SimpleImportModelConfig = {
};
export const InstallModelForm = () => {
const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModelMutation();
const [installModel, { isLoading }] = useInstallModel();
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
defaultValues: {
@ -26,40 +21,22 @@ export const InstallModelForm = () => {
mode: 'onChange',
});
const resetForm = useCallback(() => reset(undefined, { keepValues: true }), [reset]);
const onSubmit = useCallback<SubmitHandler<SimpleImportModelConfig>>(
(values) => {
if (!values?.location) {
return;
}
installModel({ source: values.location, inplace: values.inplace })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
reset(undefined, { keepValues: true });
})
.catch((error) => {
reset(undefined, { keepValues: true });
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
installModel({
source: values.location,
inplace: values.inplace,
onSuccess: resetForm,
onError: resetForm,
});
},
[dispatch, reset, installModel]
[installModel, resetForm]
);
return (

View File

@ -1,8 +1,6 @@
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
@ -10,8 +8,6 @@ import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } fro
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
export const ModelInstallQueue = () => {
const dispatch = useAppDispatch();
const { data } = useListModelInstallsQuery();
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
@ -20,28 +16,22 @@ export const ModelInstallQueue = () => {
_pruneCompletedModelInstalls()
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.prunedQueue'),
status: 'success',
})
)
);
toast({
id: 'MODEL_INSTALL_QUEUE_PRUNED',
title: t('toast.prunedQueue'),
status: 'success',
});
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
toast({
id: 'MODEL_INSTALL_QUEUE_PRUNE_FAILED',
title: `${error.data.detail} `,
status: 'error',
});
}
});
}, [_pruneCompletedModelInstalls, dispatch]);
}, [_pruneCompletedModelInstalls]);
const pruneAvailable = useMemo(() => {
return data?.some(

View File

@ -1,7 +1,5 @@
import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@ -29,7 +27,6 @@ const formatBytes = (bytes: number) => {
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const { installJob } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useCancelModelInstallMutation();
@ -37,28 +34,22 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
deleteImportModel(installJob.id)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelImportCanceled'),
status: 'success',
})
)
);
toast({
id: 'MODEL_INSTALL_CANCELED',
title: t('toast.modelImportCanceled'),
status: 'success',
});
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
toast({
id: 'MODEL_INSTALL_CANCEL_FAILED',
title: `${error.data.detail} `,
status: 'error',
});
}
});
}, [deleteImportModel, installJob, dispatch]);
}, [deleteImportModel, installJob]);
const sourceLocation = useMemo(() => {
switch (installJob.source.type) {

View File

@ -11,15 +11,13 @@ import {
InputGroup,
InputRightElement,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import type { ChangeEvent, ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanFolderResultItem';
@ -30,9 +28,8 @@ type ScanModelResultsProps = {
export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const dispatch = useAppDispatch();
const [inplace, setInplace] = useState(true);
const [installModel] = useInstallModelMutation();
const [installModel] = useInstallModel();
const filteredResults = useMemo(() => {
return results.filter((result) => {
@ -58,61 +55,15 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
if (result.is_installed) {
continue;
}
installModel({ source: result.path, inplace })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
installModel({ source: result.path, inplace });
}
}, [filteredResults, installModel, inplace, dispatch, t]);
}, [filteredResults, installModel, inplace]);
const handleInstallOne = useCallback(
(source: string) => {
installModel({ source, inplace })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
installModel({ source, inplace });
},
[installModel, inplace, dispatch, t]
[installModel, inplace]
);
return (

View File

@ -1,20 +1,16 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
result: GetStarterModelsResponse[number];
};
export const StarterModelsResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const allSources = useMemo(() => {
const _allSources = [result.source];
if (result.dependencies) {
@ -22,36 +18,13 @@ export const StarterModelsResultItem = ({ result }: Props) => {
}
return _allSources;
}, [result]);
const [installModel] = useInstallModelMutation();
const [installModel] = useInstallModel();
const handleQuickAdd = useCallback(() => {
const onClick = useCallback(() => {
for (const source of allSources) {
installModel({ source })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
installModel({ source });
}
}, [allSources, installModel, dispatch, t]);
}, [allSources, installModel]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
@ -67,7 +40,7 @@ export const StarterModelsResultItem = ({ result }: Props) => {
{result.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
)}
</Box>
</Flex>

View File

@ -4,8 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -53,25 +52,19 @@ const ModelListItem = (props: ModelListItemProps) => {
deleteModel({ key: model.key })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success',
})
)
);
toast({
id: 'MODEL_DELETED',
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success',
});
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error',
})
)
);
toast({
id: 'MODEL_DELETE_FAILED',
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error',
});
}
});
dispatch(setSelectedModelKey(null));

View File

@ -1,10 +1,9 @@
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
@ -19,7 +18,6 @@ export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
export const ControlNetOrT2IAdapterDefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } =
useControlNetOrT2IAdapterDefaultSettings(selectedModelKey);
@ -46,30 +44,24 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
})
)
);
toast({
id: 'DEFAULT_SETTINGS_SAVED',
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
});
reset(data);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
toast({
id: 'DEFAULT_SETTINGS_SAVE_FAILED',
title: `${error.data.detail} `,
status: 'error',
});
}
});
},
[selectedModelKey, dispatch, reset, updateModel, t]
[selectedModelKey, reset, updateModel, t]
);
if (isLoadingDefaultSettings) {

View File

@ -1,8 +1,6 @@
import { Box, Button, Flex, Icon, IconButton, Image, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { typedMemo } from 'common/util/typedMemo';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { useCallback, useState } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
@ -15,7 +13,6 @@ type Props = {
};
const ModelImageUpload = ({ model_key, model_image }: Props) => {
const dispatch = useAppDispatch();
const [image, setImage] = useState<string | null>(model_image || null);
const { t } = useTranslation();
@ -34,27 +31,21 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
.unwrap()
.then(() => {
setImage(URL.createObjectURL(file));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageUpdated'),
status: 'success',
})
)
);
toast({
id: 'MODEL_IMAGE_UPDATED',
title: t('modelManager.modelImageUpdated'),
status: 'success',
});
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageUpdateFailed'),
status: 'error',
})
)
);
.catch(() => {
toast({
id: 'MODEL_IMAGE_UPDATE_FAILED',
title: t('modelManager.modelImageUpdateFailed'),
status: 'error',
});
});
},
[dispatch, model_key, t, updateModelImage]
[model_key, t, updateModelImage]
);
const handleResetImage = useCallback(() => {
@ -65,26 +56,20 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
deleteModelImage(model_key)
.unwrap()
.then(() => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageDeleted'),
status: 'success',
})
)
);
toast({
id: 'MODEL_IMAGE_DELETED',
title: t('modelManager.modelImageDeleted'),
status: 'success',
});
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageDeleteFailed'),
status: 'error',
})
)
);
.catch(() => {
toast({
id: 'MODEL_IMAGE_DELETE_FAILED',
title: t('modelManager.modelImageDeleteFailed'),
status: 'error',
});
});
}, [dispatch, model_key, t, deleteModelImage]);
}, [model_key, t, deleteModelImage]);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },

View File

@ -1,11 +1,10 @@
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
@ -39,7 +38,6 @@ export type MainModelDefaultSettingsFormData = {
export const MainModelDefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const {
defaultSettingsDefaults,
@ -76,30 +74,24 @@ export const MainModelDefaultSettings = () => {
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
})
)
);
toast({
id: 'DEFAULT_SETTINGS_SAVED',
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
});
reset(data);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
toast({
id: 'DEFAULT_SETTINGS_SAVE_FAILED',
title: `${error.data.detail} `,
status: 'error',
});
}
});
},
[selectedModelKey, dispatch, reset, updateModel, t]
[selectedModelKey, reset, updateModel, t]
);
if (isLoadingDefaultSettings) {

View File

@ -4,8 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
@ -47,25 +46,19 @@ export const Model = () => {
.then((payload) => {
form.reset(payload, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
toast({
id: 'MODEL_UPDATED',
title: t('modelManager.modelUpdated'),
status: 'success',
});
})
.catch((_) => {
form.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
toast({
id: 'MODEL_UPDATE_FAILED',
title: t('modelManager.modelUpdateFailed'),
status: 'error',
});
});
},
[dispatch, data?.key, form, t, updateModel]

View File

@ -9,9 +9,7 @@ import {
useDisclosure,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models';
@ -22,7 +20,6 @@ interface ModelConvertProps {
export const ModelConvertButton = (props: ModelConvertProps) => {
const { modelKey } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data } = useGetModelConfigQuery(modelKey ?? skipToken);
const [convertModel, { isLoading }] = useConvertModelMutation();
@ -33,38 +30,26 @@ export const ModelConvertButton = (props: ModelConvertProps) => {
return;
}
dispatch(
addToast(
makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`,
status: 'info',
})
)
);
const toastId = `CONVERTING_MODEL_${data.key}`;
toast({
id: toastId,
title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`,
status: 'info',
});
convertModel(data?.key)
.unwrap()
.then(() => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConverted')}: ${data?.name}`,
status: 'success',
})
)
);
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${data?.name}`, status: 'success' });
})
.catch(() => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`,
status: 'error',
})
)
);
toast({
id: toastId,
title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`,
status: 'error',
});
});
}, [data, isLoading, dispatch, t, convertModel]);
}, [data, isLoading, t, convertModel]);
if (data?.format !== 'checkpoint') {
return;

View File

@ -72,10 +72,12 @@ export const ModelEdit = ({ form }: Props) => {
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={form.control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={form.control} />
</FormControl>
{data.type === 'main' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={form.control} />
</FormControl>
)}
{data.type === 'main' && data.format === 'checkpoint' && (
<>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>

View File

@ -3,7 +3,6 @@ import 'reactflow/dist/style.css';
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
@ -24,6 +23,7 @@ import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { toast } from 'features/toast/toast';
import { filter, map, memoize, some } from 'lodash-es';
import { memo, useCallback, useMemo, useRef } from 'react';
import { flushSync } from 'react-dom';
@ -60,7 +60,6 @@ const filterOption = memoize((option: FilterOptionOption<ComboboxOption>, inputV
const AddNodePopover = () => {
const dispatch = useAppDispatch();
const buildInvocation = useBuildNode();
const toaster = useAppToaster();
const { t } = useTranslation();
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
const inputRef = useRef<HTMLInputElement>(null);
@ -127,7 +126,7 @@ const AddNodePopover = () => {
const errorMessage = t('nodes.unknownNode', {
nodeType: nodeType,
});
toaster({
toast({
status: 'error',
title: errorMessage,
});
@ -163,7 +162,7 @@ const AddNodePopover = () => {
}
return node;
},
[buildInvocation, store, dispatch, t, toaster]
[buildInvocation, store, dispatch, t]
);
const onChange = useCallback<ComboboxOnChange>(

View File

@ -1,7 +1,7 @@
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames';
import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames';
import { InvocationInputFieldCheck } from 'features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck';
import { useFieldNames } from 'features/nodes/hooks/useFieldNames';
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
import { useWithFooter } from 'features/nodes/hooks/useWithFooter';
import { memo } from 'react';
@ -20,8 +20,7 @@ type Props = {
};
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId);
const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId);
const fieldNames = useFieldNames(nodeId);
const withFooter = useWithFooter(nodeId);
const outputFieldNames = useOutputFieldNames(nodeId);
@ -41,9 +40,11 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
>
<Flex flexDir="column" px={2} w="full" h="full">
<Grid gridTemplateColumns="1fr auto" gridAutoRows="1fr">
{inputConnectionFieldNames.map((fieldName, i) => (
{fieldNames.connectionFields.map((fieldName, i) => (
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
<InputField nodeId={nodeId} fieldName={fieldName} />
<InvocationInputFieldCheck nodeId={nodeId} fieldName={fieldName}>
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
</GridItem>
))}
{outputFieldNames.map((fieldName, i) => (
@ -52,8 +53,23 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
</GridItem>
))}
</Grid>
{inputAnyOrDirectFieldNames.map((fieldName) => (
<InputField key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName} />
{fieldNames.anyOrDirectFields.map((fieldName) => (
<InvocationInputFieldCheck
key={`${nodeId}.${fieldName}.input-field`}
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
))}
{fieldNames.missingFields.map((fieldName) => (
<InvocationInputFieldCheck
key={`${nodeId}.${fieldName}.input-field`}
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
))}
</Flex>
</Flex>

Some files were not shown because too many files have changed in this diff Show More