Merge branch 'main' into mm-ui

This commit is contained in:
blessedcoolant 2023-07-18 12:58:57 +12:00
commit ec3c15ead0
29 changed files with 833 additions and 6996 deletions

View File

@ -81,3 +81,193 @@ pytest --cov; open ./coverage/html/index.html
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.--> <!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
--8<-- "invokeai/frontend/web/README.md" --8<-- "invokeai/frontend/web/README.md"
## Developing InvokeAI in VSCode
VSCode offers some nice tools:
- python debugger
- automatic `venv` activation
- remote dev (e.g. run InvokeAI on a beefy linux desktop while you type in
comfort on your macbook)
### Setup
You'll need the
[Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
and
[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance)
extensions installed first.
It's also really handy to install the `Jupyter` extensions:
- [Jupyter](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)
- [Jupyter Cell Tags](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-cell-tags)
- [Jupyter Notebook Renderers](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter-renderers)
- [Jupyter Slide Show](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-slideshow)
#### InvokeAI workspace
Creating a VSCode workspace for working on InvokeAI is highly recommended. It
can hold InvokeAI-specific settings and configs.
To make a workspace:
- Open the InvokeAI repo dir in VSCode
- `File` > `Save Workspace As` > save it _outside_ the repo
#### Default python interpreter (i.e. automatic virtual environment activation)
- Use command palette to run command
`Preferences: Open Workspace Settings (JSON)`
- Add `python.defaultInterpreterPath` to `settings`, pointing to your `venv`'s
python
Should look something like this:
```json
{
// I like to have all InvokeAI-related folders in my workspace
"folders": [
{
// repo root
"path": "InvokeAI"
},
{
// InvokeAI root dir, where `invokeai.yaml` lives
"path": "/path/to/invokeai_root"
}
],
"settings": {
// Where your InvokeAI `venv`'s python executable lives
"python.defaultInterpreterPath": "/path/to/invokeai_root/.venv/bin/python"
}
}
```
Now when you open the VSCode integrated terminal, or do anything that needs to
run python, it will automatically be in your InvokeAI virtual environment.
Bonus: When you create a Jupyter notebook, when you run it, you'll be prompted
for the python interpreter to run in. This will default to your `venv` python,
and so you'll have access to the same python environment as the InvokeAI app.
This is _super_ handy.
#### Debugging configs with `launch.json`
Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
these can be scoped to a workspace or folder.
Follow the [official guide](https://code.visualstudio.com/docs/python/debugging)
to set up your `launch.json` and try it out.
Now we can create the InvokeAI debugging configs:
```json
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
// Run the InvokeAI backend & serve the pre-built UI
"name": "InvokeAI Web",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-web.py",
"args": [
// Your InvokeAI root dir (where `invokeai.yaml` lives)
"--root",
"/path/to/invokeai_root",
// Access the app from anywhere on your local network
"--host",
"0.0.0.0"
],
"justMyCode": true
},
{
// Run the nodes-based CLI
"name": "InvokeAI CLI",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-cli.py",
"justMyCode": true
},
{
// Run tests
"name": "InvokeAI Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["--capture=no"],
"justMyCode": true
},
{
// Run a single test
"name": "InvokeAI Single Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": [
// Change this to point to the specific test you are working on
"tests/nodes/test_invoker.py"
],
"justMyCode": true
},
{
// This is the default, useful to just run a single file
"name": "Python: File",
"type": "python",
"request": "launch",
"program": "${file}",
"justMyCode": true
}
]
}
```
You'll see these configs in the debugging configs drop down. Running them will
start InvokeAI with attached debugger, in the correct environment, and work just
like the normal app.
Enjoy debugging InvokeAI with ease (not that we have any bugs of course).
#### Remote dev
This is very easy to set up and provides the same very smooth experience as
local development. Environments and debugging, as set up above, just work,
though you'd need to recreate the workspace and debugging configs on the remote.
Consult the
[official guide](https://code.visualstudio.com/docs/remote/remote-overview) to
get it set up.
Suggest using VSCode's included settings sync so that your remote dev host has
all the same app settings and extensions automagically.
##### One remote dev gotcha
I've found the automatic port forwarding to be very flakey. You can disable it
in `Preferences: Open Remote Settings (ssh: hostname)`. Search for
`remote.autoForwardPorts` and untick the box.
To forward ports very reliably, use SSH on the remote dev client (e.g. your
macbook). Here's how to forward both backend API port (`9090`) and the frontend
live dev server port (`5173`):
```bash
ssh \
-L 9090:localhost:9090 \
-L 5173:localhost:5173 \
user@remote-dev-host
```
The forwarding stops when you close the terminal window, so suggest to do this
_outside_ the VSCode integrated terminal in case you need to restart VSCode for
an extension update or something
Now, on your remote dev client, you can open `localhost:9090` and access the UI,
now served from the remote dev host, just the same as if it was running on the
client.

View File

@ -57,10 +57,10 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), **self.clip.tokenizer.dict(), context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(), context=context,
) )
def _lora_loader(): def _lora_loader():
@ -82,6 +82,7 @@ class CompelInvocation(BaseInvocation):
model_name=name, model_name=name,
base_model=self.clip.text_encoder.base_model, base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:

View File

@ -157,13 +157,13 @@ class InpaintInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context,)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
with vae_info as vae,\ with vae_info as vae,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\

View File

@ -76,7 +76,7 @@ def get_scheduler(
scheduler_name, SCHEDULER_MAP['ddim'] scheduler_name, SCHEDULER_MAP['ddim']
) )
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict() **scheduler_info.dict(), context=context,
) )
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@ -262,6 +262,7 @@ class TextToLatentsInvocation(BaseInvocation):
model_name=control_info.control_model.model_name, model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet, model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model, base_model=control_info.control_model.base_model,
context=context,
) )
) )
@ -313,14 +314,14 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -403,14 +404,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -491,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
with vae_info as vae: with vae_info as vae:
@ -636,7 +637,7 @@ class ImageToLatentsInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

View File

@ -105,8 +105,6 @@ class EventServiceBase:
def emit_model_load_started ( def emit_model_load_started (
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -117,8 +115,6 @@ class EventServiceBase:
event_name="model_load_started", event_name="model_load_started",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -129,8 +125,6 @@ class EventServiceBase:
def emit_model_load_completed( def emit_model_load_completed(
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -142,12 +136,12 @@ class EventServiceBase:
event_name="model_load_completed", event_name="model_load_completed",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info, hash=model_info.hash,
location=model_info.location,
precision=str(model_info.precision),
), ),
) )

View File

@ -339,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> ModelInfo: ) -> ModelInfo:
""" """
@ -347,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
part (such as the vae) of a diffusers mode. part (such as the vae) of a diffusers mode.
""" """
# if we are called from within a node, then we get to emit # we can emit model loading events if we are executing with access to the invocation context
# load start and complete events if context:
if node and context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -366,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
submodel, submodel,
) )
if node and context: if context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -510,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event( def _emit_load_event(
self, self,
node,
context, context,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: SubModelType, submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None, model_info: Optional[ModelInfo] = None,
): ):
if context.services.queue.is_canceled(context.graph_execution_state_id): if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException() raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if model_info: if model_info:
context.services.events.emit_model_load_completed( context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -536,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
else: else:
context.services.events.emit_model_load_started( context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,

View File

@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None, run_id=None,
**kwargs,
) -> InvokeAIStableDiffusionPipelineOutput: ) -> InvokeAIStableDiffusionPipelineOutput:
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise, noise=noise,
run_id=run_id, run_id=run_id,
callback=callback, callback=callback,
**kwargs,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None, run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False): if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu') scheduler_device = torch.device('cpu')
@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps, timesteps,
conditioning_data, conditioning_data,
noise=noise, noise=noise,
additional_guidance=additional_guidance,
run_id=run_id, run_id=run_id,
callback=callback, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
**kwargs,
callback=callback,
) )
return result.latents, result.attention_map_saver return result.latents, result.attention_map_saver
@ -505,42 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None, run_id: str = None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
**kwargs,
): ):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
return cond, encoder_attention_mask
encoder_attention_mask = None
if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
)
conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
conditioning_data.text_embeddings, max_len, encoder_attention_mask
)
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
if run_id is None: if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH) run_id = secrets.token_urlsafe(self.ID_LENGTH)
@ -580,8 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps), total_step_count=len(timesteps),
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
) )
latents = step_output.prev_sample latents = step_output.prev_sample
@ -623,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int, total_step_count: int,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
**kwargs,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
@ -638,8 +597,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None: if control_data is not None:
# TODO: rewrite to pass with conditionings
encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
# control_data should be type List[ControlNetData] # control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list) # this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list) # and MultiControlNet (multiple ControlNetData in list)
@ -669,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = conditioning_data.text_embeddings encoder_hidden_states = conditioning_data.text_embeddings
encoder_attention_mask = None
else: else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
conditioning_data.text_embeddings]) conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
)
if isinstance(control_datum.weight, list): if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step # if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index] controlnet_weight = control_datum.weight[step_index]

View File

@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
) )
return latents return latents
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
return cond, encoder_attention_mask
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
# methods below are called from do_diffusion_step and should be considered private to this class. # methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning]) both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
unconditioning, conditioning
)
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs, x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
) )
unconditioned_next_x, conditioned_next_x = both_results.chunk(2) unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -260,8 +297,32 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
): ):
# low-memory sequential path # low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) uncond_down_block, cond_down_block = None, None
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
# TODO: looks unused # TODO: looks unused
@ -295,6 +356,20 @@ class InvokeAIDiffuserComponent:
): ):
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
cross_attn_processor_context = SwapCrossAttnContext( cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning, modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map, index_map=context.cross_attention_index_map,
@ -307,6 +382,8 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
unconditioning, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs, **kwargs,
) )
@ -319,6 +396,8 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
conditioning, conditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs, **kwargs,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x

View File

@ -577,6 +577,7 @@
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"downloadImageStarted": "Image Download Started", "downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied", "imageCopied": "Image Copied",
"problemCopyingImage": "Unable to Copy Image",
"imageLinkCopied": "Image Link Copied", "imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link", "problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded", "imageNotLoaded": "No Image Loaded",

View File

@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -177,6 +179,8 @@ addSocketConnectedListener();
addSocketDisconnectedListener(); addSocketDisconnectedListener();
addSocketSubscribedListener(); addSocketSubscribedListener();
addSocketUnsubscribedListener(); addSocketUnsubscribedListener();
addModelLoadStartedEventListener();
addModelLoadCompletedEventListener();
// Session Created // Session Created
addSessionCreatedPendingListener(); addSessionCreatedPendingListener();

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadCompleted,
socketModelLoadCompleted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadCompletedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadCompleted(action.payload));
},
});
};

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadStarted,
socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadStartedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load started (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadStarted(action.payload));
},
});
};

View File

@ -21,6 +21,7 @@ import { ImageDTO } from 'services/api/types';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable'; import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable'; import IAIDroppable from './IAIDroppable';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
type IAIDndImageProps = { type IAIDndImageProps = {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
@ -96,119 +97,124 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}; };
return ( return (
<Flex <ImageContextMenu imageDTO={imageDTO}>
sx={{ {(ref) => (
width: 'full',
height: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}}
>
{imageDTO && (
<Flex <Flex
ref={ref}
sx={{ sx={{
w: 'full', width: 'full',
h: 'full', height: 'full',
position: fitContainer ? 'absolute' : 'relative',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}} }}
> >
<Image {imageDTO && (
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url} <Flex
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{ sx={{
boxSize: 16, w: 'full',
h: 'full',
position: fitContainer ? 'absolute' : 'relative',
alignItems: 'center',
justifyContent: 'center',
}}
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
</Flex>
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}} }}
/> />
</Flex> )}
</> </Flex>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} </ImageContextMenu>
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
)}
</Flex>
); );
}; };

View File

@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
/** /**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime. * Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/ */
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime'); export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@ -48,6 +48,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover'; import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions'; import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton'; import IAICanvasUndoButton from './IAICanvasUndoButton';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
export const selector = createSelector( export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector], [systemSelector, canvasSelector, isStagingSelector],
@ -79,6 +80,7 @@ const IAICanvasToolbar = () => {
const canvasBaseLayer = getCanvasBaseLayer(); const canvasBaseLayer = getCanvasBaseLayer();
const { t } = useTranslation(); const { t } = useTranslation();
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
const { openUploader } = useImageUploader(); const { openUploader } = useImageUploader();
@ -136,10 +138,10 @@ const IAICanvasToolbar = () => {
handleCopyImageToClipboard(); handleCopyImageToClipboard();
}, },
{ {
enabled: () => !isStaging, enabled: () => !isStaging && isClipboardAPIAvailable,
preventDefault: true, preventDefault: true,
}, },
[canvasBaseLayer, isProcessing] [canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
); );
useHotkeys( useHotkeys(
@ -189,6 +191,9 @@ const IAICanvasToolbar = () => {
}; };
const handleCopyImageToClipboard = () => { const handleCopyImageToClipboard = () => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard()); dispatch(canvasCopiedToClipboard());
}; };
@ -256,13 +261,15 @@ const IAICanvasToolbar = () => {
onClick={handleSaveToGallery} onClick={handleSaveToGallery}
isDisabled={isStaging} isDisabled={isStaging}
/> />
<IAIIconButton {isClipboardAPIAvailable && (
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`} <IAIIconButton
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`} aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
icon={<FaCopy />} tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
onClick={handleCopyImageToClipboard} icon={<FaCopy />}
isDisabled={isStaging} onClick={handleCopyImageToClipboard}
/> isDisabled={isStaging}
/>
)}
<IAIIconButton <IAIIconButton
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`} aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`} tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}

View File

@ -1,7 +1,16 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { ButtonGroup, Flex, FlexProps, Link } from '@chakra-ui/react'; import {
ButtonGroup,
Flex,
FlexProps,
Link,
Menu,
MenuButton,
MenuItem,
MenuList,
} from '@chakra-ui/react';
// import { runESRGAN, runFacetool } from 'app/socketio/actions'; // import { runESRGAN, runFacetool } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
@ -20,6 +29,7 @@ import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/U
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { import {
setActiveTab, setActiveTab,
@ -48,6 +58,8 @@ import {
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce'; import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { menuListMotionProps } from 'theme/components/menu';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -120,6 +132,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toaster = useAppToaster(); const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const { recallBothPrompts, recallSeed, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters(); useRecallParameters();
@ -128,7 +143,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
500 500
); );
const { currentData: image, isFetching } = useGetImageDTOQuery( const { currentData: imageDTO, isFetching } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken lastSelectedImage ?? skipToken
); );
@ -142,15 +157,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleCopyImageLink = useCallback(() => { const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => { const getImageUrl = () => {
if (!image) { if (!imageDTO) {
return; return;
} }
if (image.image_url.startsWith('http')) { if (imageDTO.image_url.startsWith('http')) {
return image.image_url; return imageDTO.image_url;
} }
return window.location.toString() + image.image_url; return window.location.toString() + imageDTO.image_url;
}; };
const url = getImageUrl(); const url = getImageUrl();
@ -174,7 +189,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
isClosable: true, isClosable: true,
}); });
}); });
}, [toaster, t, image]); }, [toaster, t, imageDTO]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata); recallAllParameters(metadata);
@ -192,31 +207,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
recallSeed(metadata?.seed); recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]); }, [metadata?.seed, recallSeed]);
useHotkeys('s', handleUseSeed, [image]); useHotkeys('s', handleUseSeed, [imageDTO]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt); recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]); }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]); useHotkeys('p', handleUsePrompt, [imageDTO]);
const handleSendToImageToImage = useCallback(() => { const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img()); dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image)); dispatch(initialImageSelected(imageDTO));
}, [dispatch, image]); }, [dispatch, imageDTO]);
useHotkeys('shift+i', handleSendToImageToImage, [image]); useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
const handleClickUpscale = useCallback(() => { const handleClickUpscale = useCallback(() => {
// selectedImage && dispatch(runESRGAN(selectedImage)); // selectedImage && dispatch(runESRGAN(selectedImage));
}, []); }, []);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
if (!image) { if (!imageDTO) {
return; return;
} }
dispatch(imageToDeleteSelected(image)); dispatch(imageToDeleteSelected(imageDTO));
}, [dispatch, image]); }, [dispatch, imageDTO]);
useHotkeys( useHotkeys(
'Shift+U', 'Shift+U',
@ -236,7 +251,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, },
[ [
isUpscalingEnabled, isUpscalingEnabled,
image, imageDTO,
isESRGANAvailable, isESRGANAvailable,
shouldDisableToolbarButtons, shouldDisableToolbarButtons,
isConnected, isConnected,
@ -268,7 +283,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
[ [
isFaceRestoreEnabled, isFaceRestoreEnabled,
image, imageDTO,
isGFPGANAvailable, isGFPGANAvailable,
shouldDisableToolbarButtons, shouldDisableToolbarButtons,
isConnected, isConnected,
@ -283,10 +298,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
); );
const handleSendToCanvas = useCallback(() => { const handleSendToCanvas = useCallback(() => {
if (!image) return; if (!imageDTO) return;
dispatch(sentImageToCanvas()); dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image)); dispatch(setInitialCanvasImage(imageDTO));
dispatch(requestCanvasRescale()); dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') { if (activeTabName !== 'unifiedCanvas') {
@ -299,12 +314,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
duration: 2500, duration: 2500,
isClosable: true, isClosable: true,
}); });
}, [image, dispatch, activeTabName, toaster, t]); }, [imageDTO, dispatch, activeTabName, toaster, t]);
useHotkeys( useHotkeys(
'i', 'i',
() => { () => {
if (image) { if (imageDTO) {
handleClickShowImageDetails(); handleClickShowImageDetails();
} else { } else {
toaster({ toaster({
@ -315,13 +330,20 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}); });
} }
}, },
[image, shouldShowImageDetails, toaster] [imageDTO, shouldShowImageDetails, toaster]
); );
const handleClickProgressImagesToggle = useCallback(() => { const handleClickProgressImagesToggle = useCallback(() => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer)); dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]); }, [dispatch, shouldShowProgressInViewer]);
const handleCopyImage = useCallback(() => {
if (!imageDTO) {
return;
}
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO]);
return ( return (
<> <>
<Flex <Flex
@ -334,63 +356,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{...props} {...props}
> >
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIPopover <Menu>
triggerComponent={ <MenuButton
<IAIIconButton as={IAIIconButton}
aria-label={`${t('parameters.sendTo')}...`} aria-label={`${t('parameters.sendTo')}...`}
tooltip={`${t('parameters.sendTo')}...`} tooltip={`${t('parameters.sendTo')}...`}
isDisabled={!image} isDisabled={!imageDTO}
icon={<FaShareAlt />} icon={<FaShareAlt />}
/> />
} <MenuList motionProps={menuListMotionProps}>
> {imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}
<Flex </MenuList>
sx={{ </Menu>
flexDirection: 'column',
rowGap: 2,
}}
>
<IAIButton
size="sm"
onClick={handleSendToImageToImage}
leftIcon={<FaShare />}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</IAIButton>
{isCanvasEnabled && (
<IAIButton
size="sm"
onClick={handleSendToCanvas}
leftIcon={<FaShare />}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</IAIButton>
)}
{/* <IAIButton
size="sm"
onClick={handleCopyImage}
leftIcon={<FaCopy />}
>
{t('parameters.copyImage')}
</IAIButton> */}
<IAIButton
size="sm"
onClick={handleCopyImageLink}
leftIcon={<FaCopy />}
>
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={image?.image_url} target="_blank">
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
</Link>
</Flex>
</IAIPopover>
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
@ -443,7 +420,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton <IAIButton
isDisabled={ isDisabled={
!isGFPGANAvailable || !isGFPGANAvailable ||
!image || !imageDTO ||
!(isConnected && !isProcessing) || !(isConnected && !isProcessing) ||
!facetoolStrength !facetoolStrength
} }
@ -474,7 +451,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton <IAIButton
isDisabled={ isDisabled={
!isESRGANAvailable || !isESRGANAvailable ||
!image || !imageDTO ||
!(isConnected && !isProcessing) || !(isConnected && !isProcessing) ||
!upscalingLevel !upscalingLevel
} }

View File

@ -4,13 +4,14 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import { memo, useMemo } from 'react'; import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems'; import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems'; import SingleSelectionMenuItems from './SingleSelectionMenuItems';
type Props = { type Props = {
imageDTO: ImageDTO; imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children']; children: ContextMenuProps<HTMLDivElement>['children'];
}; };
@ -31,18 +32,32 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
const { selectionCount } = useAppSelector(selector); const { selectionCount } = useAppSelector(selector);
const handleContextMenu = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.preventDefault();
}, []);
return ( return (
<ContextMenu<HTMLDivElement> <ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => ( menuButtonProps={{
<MenuList sx={{ visibility: 'visible !important' }}> bg: 'transparent',
{selectionCount === 1 ? ( _hover: { bg: 'transparent' },
<SingleSelectionMenuItems imageDTO={imageDTO} /> }}
) : ( renderMenu={() =>
<MultipleSelectionMenuItems /> imageDTO ? (
)} <MenuList
</MenuList> sx={{ visibility: 'visible !important' }}
)} motionProps={menuListMotionProps}
onContextMenu={handleContextMenu}
>
{selectionCount === 1 ? (
<SingleSelectionMenuItems imageDTO={imageDTO} />
) : (
<MultipleSelectionMenuItems />
)}
</MenuList>
) : null
}
> >
{children} {children}
</ContextMenu> </ContextMenu>

View File

@ -1,5 +1,4 @@
import { ExternalLinkIcon } from '@chakra-ui/icons'; import { Link, MenuItem } from '@chakra-ui/react';
import { MenuItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
@ -14,11 +13,21 @@ import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletio
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useContext, useMemo } from 'react'; import { memo, useCallback, useContext, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaFolder, FaShare, FaTrash } from 'react-icons/fa'; import {
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; FaAsterisk,
FaCopy,
FaDownload,
FaExternalLinkAlt,
FaFolder,
FaQuoteRight,
FaSeedling,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages'; import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
@ -61,6 +70,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { currentData } = useGetImageMetadataQuery(imageDTO.image_name); const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const metadata = currentData?.metadata; const metadata = currentData?.metadata;
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
@ -130,13 +142,27 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
dispatch(imagesAddedToBatch([imageDTO.image_name])); dispatch(imagesAddedToBatch([imageDTO.image_name]));
}, [dispatch, imageDTO.image_name]); }, [dispatch, imageDTO.image_name]);
const handleCopyImage = useCallback(() => {
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO.image_url]);
return ( return (
<> <>
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}> <Link href={imageDTO.image_url} target="_blank">
{t('common.openInNewTab')} <MenuItem
</MenuItem> icon={<FaExternalLinkAlt />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
</Link>
{isClipboardAPIAvailable && (
<MenuItem icon={<FaCopy />} onClickCapture={handleCopyImage}>
{t('parameters.copyImage')}
</MenuItem>
)}
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaQuoteRight />}
onClickCapture={handleRecallPrompt} onClickCapture={handleRecallPrompt}
isDisabled={ isDisabled={
metadata?.positive_prompt === undefined && metadata?.positive_prompt === undefined &&
@ -147,14 +173,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaSeedling />}
onClickCapture={handleRecallSeed} onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined} isDisabled={metadata?.seed === undefined}
> >
{t('parameters.useSeed')} {t('parameters.useSeed')}
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaAsterisk />}
onClickCapture={handleUseAllParameters} onClickCapture={handleUseAllParameters}
isDisabled={!metadata} isDisabled={!metadata}
> >
@ -193,6 +219,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
Remove from Board Remove from Board
</MenuItem> </MenuItem>
)} )}
<Link download={true} href={imageDTO.image_url} target="_blank">
<MenuItem icon={<FaDownload />} w="100%">
{t('parameters.downloadImage')}
</MenuItem>
</Link>
<MenuItem <MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }} sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />} icon={<FaTrash />}

View File

@ -2,9 +2,12 @@ import { ButtonGroup } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa'; import { FaCode, FaExpand, FaMinus, FaPlus, FaInfo } from 'react-icons/fa';
import { useReactFlow } from 'reactflow'; import { useReactFlow } from 'reactflow';
import { shouldShowGraphOverlayChanged } from '../store/nodesSlice'; import {
shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
} from '../store/nodesSlice';
const ViewportControls = () => { const ViewportControls = () => {
const { zoomIn, zoomOut, fitView } = useReactFlow(); const { zoomIn, zoomOut, fitView } = useReactFlow();
@ -12,6 +15,9 @@ const ViewportControls = () => {
const shouldShowGraphOverlay = useAppSelector( const shouldShowGraphOverlay = useAppSelector(
(state) => state.nodes.shouldShowGraphOverlay (state) => state.nodes.shouldShowGraphOverlay
); );
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
const handleClickedZoomIn = useCallback(() => { const handleClickedZoomIn = useCallback(() => {
zoomIn(); zoomIn();
@ -29,6 +35,10 @@ const ViewportControls = () => {
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay)); dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
}, [shouldShowGraphOverlay, dispatch]); }, [shouldShowGraphOverlay, dispatch]);
const handleClickedToggleFieldTypeLegend = useCallback(() => {
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
}, [shouldShowFieldTypeLegend, dispatch]);
return ( return (
<ButtonGroup isAttached orientation="vertical"> <ButtonGroup isAttached orientation="vertical">
<IAIIconButton <IAIIconButton
@ -52,6 +62,12 @@ const ViewportControls = () => {
aria-label="Show/Hide Graph" aria-label="Show/Hide Graph"
icon={<FaCode />} icon={<FaCode />}
/> />
<IAIIconButton
isChecked={shouldShowFieldTypeLegend}
onClick={handleClickedToggleFieldTypeLegend}
aria-label="Show/Hide Field Type Legend"
icon={<FaInfo />}
/>
</ButtonGroup> </ButtonGroup>
); );
}; };

View File

@ -9,10 +9,13 @@ const TopRightPanel = () => {
const shouldShowGraphOverlay = useAppSelector( const shouldShowGraphOverlay = useAppSelector(
(state: RootState) => state.nodes.shouldShowGraphOverlay (state: RootState) => state.nodes.shouldShowGraphOverlay
); );
const shouldShowFieldTypeLegend = useAppSelector(
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
);
return ( return (
<Panel position="top-right"> <Panel position="top-right">
<FieldTypeLegend /> {shouldShowFieldTypeLegend && <FieldTypeLegend />}
{shouldShowGraphOverlay && <NodeGraphOverlay />} {shouldShowGraphOverlay && <NodeGraphOverlay />}
</Panel> </Panel>
); );

View File

@ -32,6 +32,7 @@ export type NodesState = {
invocationTemplates: Record<string, InvocationTemplate>; invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null; connectionStartParams: OnConnectStartParams | null;
shouldShowGraphOverlay: boolean; shouldShowGraphOverlay: boolean;
shouldShowFieldTypeLegend: boolean;
editorInstance: ReactFlowInstance | undefined; editorInstance: ReactFlowInstance | undefined;
}; };
@ -42,6 +43,7 @@ export const initialNodesState: NodesState = {
invocationTemplates: {}, invocationTemplates: {},
connectionStartParams: null, connectionStartParams: null,
shouldShowGraphOverlay: false, shouldShowGraphOverlay: false,
shouldShowFieldTypeLegend: false,
editorInstance: undefined, editorInstance: undefined,
}; };
@ -125,6 +127,12 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => { shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload; state.shouldShowGraphOverlay = action.payload;
}, },
shouldShowFieldTypeLegendChanged: (
state,
action: PayloadAction<boolean>
) => {
state.shouldShowFieldTypeLegend = action.payload;
},
nodeTemplatesBuilt: ( nodeTemplatesBuilt: (
state, state,
action: PayloadAction<Record<string, InvocationTemplate>> action: PayloadAction<Record<string, InvocationTemplate>>
@ -161,6 +169,7 @@ export const {
connectionStarted, connectionStarted,
connectionEnded, connectionEnded,
shouldShowGraphOverlayChanged, shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
nodeTemplatesBuilt, nodeTemplatesBuilt,
nodeEditorReset, nodeEditorReset,
imageCollectionFieldValueChanged, imageCollectionFieldValueChanged,

View File

@ -4,6 +4,8 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions'; import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors'; import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider'; import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa'; import { FaCopy } from 'react-icons/fa';
@ -11,6 +13,7 @@ import { FaCopy } from 'react-icons/fa';
export default function UnifiedCanvasCopyToClipboard() { export default function UnifiedCanvasCopyToClipboard() {
const isStaging = useAppSelector(isStagingSelector); const isStaging = useAppSelector(isStagingSelector);
const canvasBaseLayer = getCanvasBaseLayer(); const canvasBaseLayer = getCanvasBaseLayer();
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
const isProcessing = useAppSelector( const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing (state: RootState) => state.system.isProcessing
@ -25,15 +28,22 @@ export default function UnifiedCanvasCopyToClipboard() {
handleCopyImageToClipboard(); handleCopyImageToClipboard();
}, },
{ {
enabled: () => !isStaging, enabled: () => !isStaging && isClipboardAPIAvailable,
preventDefault: true, preventDefault: true,
}, },
[canvasBaseLayer, isProcessing] [canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
); );
const handleCopyImageToClipboard = () => { const handleCopyImageToClipboard = useCallback(() => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard()); dispatch(canvasCopiedToClipboard());
}; }, [dispatch, isClipboardAPIAvailable]);
if (!isClipboardAPIAvailable) {
return null;
}
return ( return (
<IAIIconButton <IAIIconButton

View File

@ -0,0 +1,52 @@
import { useAppToaster } from 'app/components/Toaster';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const useCopyImageToClipboard = () => {
const toaster = useAppToaster();
const { t } = useTranslation();
const isClipboardAPIAvailable = useMemo(() => {
return Boolean(navigator.clipboard) && Boolean(window.ClipboardItem);
}, []);
const copyImageToClipboard = useCallback(
async (image_url: string) => {
if (!isClipboardAPIAvailable) {
toaster({
title: t('toast.problemCopyingImage'),
description: "Your browser doesn't support the Clipboard API.",
status: 'error',
duration: 2500,
isClosable: true,
});
}
try {
const response = await fetch(image_url);
const blob = await response.blob();
await navigator.clipboard.write([
new ClipboardItem({
[blob.type]: blob,
}),
]);
toaster({
title: t('toast.imageCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
} catch (err) {
toaster({
title: t('toast.problemCopyingImage'),
description: String(err),
status: 'error',
duration: 2500,
isClosable: true,
});
}
},
[isClipboardAPIAvailable, t, toaster]
);
return { isClipboardAPIAvailable, copyImageToClipboard };
};

View File

@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models // Models
export type ModelType = components['schemas']['ModelType']; export type ModelType = components['schemas']['ModelType'];
export type SubModelType = components['schemas']['SubModelType'];
export type BaseModelType = components['schemas']['BaseModelType']; export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField']; export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField']; export type VAEModelField = components['schemas']['VAEModelField'];

View File

@ -5,6 +5,8 @@ import {
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
} from 'services/events/types'; } from 'services/events/types';
// Common socket action payload data // Common socket action payload data
@ -162,3 +164,35 @@ export const socketGeneratorProgress = createAction<
export const appSocketGeneratorProgress = createAction< export const appSocketGeneratorProgress = createAction<
BaseSocketPayload & { data: GeneratorProgressEvent } BaseSocketPayload & { data: GeneratorProgressEvent }
>('socket/appSocketGeneratorProgress'); >('socket/appSocketGeneratorProgress');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/socketModelLoadStarted');
/**
* App-level Model Load Started
*/
export const appSocketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/appSocketModelLoadStarted');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/socketModelLoadCompleted');
/**
* App-level Model Load Completed
*/
export const appSocketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/appSocketModelLoadCompleted');

View File

@ -1,5 +1,11 @@
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
import { Graph, GraphExecutionState } from '../api/types'; import {
BaseModelType,
Graph,
GraphExecutionState,
ModelType,
SubModelType,
} from '../api/types';
/** /**
* A progress image, we get one for each step in the generation * A progress image, we get one for each step in the generation
@ -25,6 +31,25 @@ export type BaseNode = {
[key: string]: AnyInvocation[keyof AnyInvocation]; [key: string]: AnyInvocation[keyof AnyInvocation];
}; };
export type ModelLoadStartedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
};
export type ModelLoadCompletedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
hash?: string;
location: string;
precision: string;
};
/** /**
* A `generator_progress` socket.io event. * A `generator_progress` socket.io event.
* *
@ -101,6 +126,8 @@ export type ServerToClientEvents = {
graph_execution_state_complete: ( graph_execution_state_complete: (
payload: GraphExecutionStateCompleteEvent payload: GraphExecutionStateCompleteEvent
) => void; ) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
}; };
export type ClientToServerEvents = { export type ClientToServerEvents = {

View File

@ -11,6 +11,8 @@ import {
socketConnected, socketConnected,
socketDisconnected, socketDisconnected,
socketSubscribed, socketSubscribed,
socketModelLoadStarted,
socketModelLoadCompleted,
} from '../actions'; } from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { ClientToServerEvents, ServerToClientEvents } from '../types';
import { Logger } from 'roarr'; import { Logger } from 'roarr';
@ -44,7 +46,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
socketSubscribed({ socketSubscribed({
sessionId, sessionId,
timestamp: getTimestamp(), timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId, boardId: getState().gallery.selectedBoardId,
}) })
); );
} }
@ -118,4 +120,28 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
}) })
); );
}); });
/**
* Model load started
*/
socket.on('model_load_started', (data) => {
dispatch(
socketModelLoadStarted({
data,
timestamp: getTimestamp(),
})
);
});
/**
* Model load completed
*/
socket.on('model_load_completed', (data) => {
dispatch(
socketModelLoadCompleted({
data,
timestamp: getTimestamp(),
})
);
});
}; };

View File

@ -1,6 +1,7 @@
import { menuAnatomy } from '@chakra-ui/anatomy'; import { menuAnatomy } from '@chakra-ui/anatomy';
import { createMultiStyleConfigHelpers } from '@chakra-ui/react'; import { createMultiStyleConfigHelpers } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools'; import { mode } from '@chakra-ui/theme-tools';
import { MotionProps } from 'framer-motion';
const { definePartsStyle, defineMultiStyleConfig } = const { definePartsStyle, defineMultiStyleConfig } =
createMultiStyleConfigHelpers(menuAnatomy.keys); createMultiStyleConfigHelpers(menuAnatomy.keys);
@ -21,6 +22,7 @@ const invokeAI = definePartsStyle((props) => ({
}, },
list: { list: {
zIndex: 9999, zIndex: 9999,
color: mode('base.900', 'base.150')(props),
bg: mode('base.200', 'base.800')(props), bg: mode('base.200', 'base.800')(props),
shadow: 'dark-lg', shadow: 'dark-lg',
border: 'none', border: 'none',
@ -35,6 +37,9 @@ const invokeAI = definePartsStyle((props) => ({
_focus: { _focus: {
bg: mode('base.400', 'base.600')(props), bg: mode('base.400', 'base.600')(props),
}, },
svg: {
opacity: 0.5,
},
}, },
})); }));
@ -46,3 +51,28 @@ export const menuTheme = defineMultiStyleConfig({
variant: 'invokeAI', variant: 'invokeAI',
}, },
}); });
export const menuListMotionProps: MotionProps = {
variants: {
enter: {
visibility: 'visible',
opacity: 1,
scale: 1,
transition: {
duration: 0.07,
ease: [0.4, 0, 0.2, 1],
},
},
exit: {
transitionEnd: {
visibility: 'hidden',
},
opacity: 0,
scale: 0.8,
transition: {
duration: 0.07,
easings: 'easeOut',
},
},
},
};

File diff suppressed because it is too large Load Diff