Merge branch 'main' into feat/model-events

Commit a137f7fe7b
@@ -81,3 +81,193 @@ pytest --cov; open ./coverage/html/index.html

<!-- #TODO: get input from blessedcoolant here; for the moment, the frontend README is inserted via the snippets extension. -->

--8<-- "invokeai/frontend/web/README.md"

## Developing InvokeAI in VSCode

VSCode offers some nice tools:

- python debugger
- automatic `venv` activation
- remote dev (e.g. run InvokeAI on a beefy linux desktop while you type in
  comfort on your macbook)

### Setup

You'll need the
[Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
and
[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance)
extensions installed first.

It's also really handy to install the `Jupyter` extensions:

- [Jupyter](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)
- [Jupyter Cell Tags](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-cell-tags)
- [Jupyter Notebook Renderers](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter-renderers)
- [Jupyter Slide Show](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-slideshow)

#### InvokeAI workspace

Creating a VSCode workspace for working on InvokeAI is highly recommended. It
can hold InvokeAI-specific settings and configs.

To make a workspace:

- Open the InvokeAI repo dir in VSCode
- `File` > `Save Workspace As` > save it _outside_ the repo

#### Default python interpreter (i.e. automatic virtual environment activation)

- Use the command palette to run
  `Preferences: Open Workspace Settings (JSON)`
- Add `python.defaultInterpreterPath` to `settings`, pointing to your `venv`'s
  python

It should look something like this:

```json
{
  // I like to have all InvokeAI-related folders in my workspace
  "folders": [
    {
      // repo root
      "path": "InvokeAI"
    },
    {
      // InvokeAI root dir, where `invokeai.yaml` lives
      "path": "/path/to/invokeai_root"
    }
  ],
  "settings": {
    // Where your InvokeAI `venv`'s python executable lives
    "python.defaultInterpreterPath": "/path/to/invokeai_root/.venv/bin/python"
  }
}
```
Now when you open the VSCode integrated terminal, or do anything that needs to
run python, it will automatically be in your InvokeAI virtual environment.
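To sanity-check this, you can run a quick command in the integrated terminal;
this is just an illustrative check, and the path shown assumes the example
config above:

```bash
# Should print your venv's interpreter, matching the workspace setting
which python
# e.g. /path/to/invokeai_root/.venv/bin/python
```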
Bonus: when you create a Jupyter notebook and run it, you'll be prompted for
the python interpreter to run in. This will default to your `venv` python, so
you'll have access to the same python environment as the InvokeAI app.

This is _super_ handy.
#### Debugging configs with `launch.json`

Debugging configs are managed in a `launch.json` file. Like most VSCode
configs, these can be scoped to a workspace or folder.

Follow the [official guide](https://code.visualstudio.com/docs/python/debugging)
to set up your `launch.json` and try it out.

Now we can create the InvokeAI debugging configs:

```json
{
  // Use IntelliSense to learn about possible attributes.
  // Hover to view descriptions of existing attributes.
  // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
  "version": "0.2.0",
  "configurations": [
    {
      // Run the InvokeAI backend & serve the pre-built UI
      "name": "InvokeAI Web",
      "type": "python",
      "request": "launch",
      "program": "scripts/invokeai-web.py",
      "args": [
        // Your InvokeAI root dir (where `invokeai.yaml` lives)
        "--root",
        "/path/to/invokeai_root",
        // Access the app from anywhere on your local network
        "--host",
        "0.0.0.0"
      ],
      "justMyCode": true
    },
    {
      // Run the nodes-based CLI
      "name": "InvokeAI CLI",
      "type": "python",
      "request": "launch",
      "program": "scripts/invokeai-cli.py",
      "justMyCode": true
    },
    {
      // Run tests
      "name": "InvokeAI Test",
      "type": "python",
      "request": "launch",
      "module": "pytest",
      "args": ["--capture=no"],
      "justMyCode": true
    },
    {
      // Run a single test
      "name": "InvokeAI Single Test",
      "type": "python",
      "request": "launch",
      "module": "pytest",
      "args": [
        // Change this to point to the specific test you are working on
        "tests/nodes/test_invoker.py"
      ],
      "justMyCode": true
    },
    {
      // This is the default, useful to just run a single file
      "name": "Python: File",
      "type": "python",
      "request": "launch",
      "program": "${file}",
      "justMyCode": true
    }
  ]
}
```
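As an aside, pytest can also target a single test function with a
`::`-qualified node ID, which works in the `args` of the "InvokeAI Single
Test" config too; the test function name below is a hypothetical example:

```bash
# Run one test function by its pytest node ID (test name here is hypothetical)
pytest tests/nodes/test_invoker.py::test_can_invoke --capture=no
```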
You'll see these configs in the debugging configs dropdown. Running them will
start InvokeAI with the debugger attached, in the correct environment, and
work just like the normal app.

Enjoy debugging InvokeAI with ease (not that we have any bugs, of course).
#### Remote dev

This is very easy to set up and provides the same very smooth experience as
local development. Environments and debugging, as set up above, just work,
though you'd need to recreate the workspace and debugging configs on the
remote.

Consult the
[official guide](https://code.visualstudio.com/docs/remote/remote-overview) to
get it set up.

We suggest using VSCode's included settings sync so that your remote dev host
has all the same app settings and extensions automagically.
##### One remote dev gotcha

I've found the automatic port forwarding to be very flaky. You can disable it
in `Preferences: Open Remote Settings (ssh: hostname)`. Search for
`remote.autoForwardPorts` and untick the box.
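If you'd rather edit the JSON directly, a minimal sketch of the equivalent
entry (in the same remote settings file that command opens) would be:

```json
{
  "remote.autoForwardPorts": false
}
```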
To forward ports very reliably, use SSH on the remote dev client (e.g. your
macbook). Here's how to forward both the backend API port (`9090`) and the
frontend live dev server port (`5173`):

```bash
ssh \
    -L 9090:localhost:9090 \
    -L 5173:localhost:5173 \
    user@remote-dev-host
```
The forwarding stops when you close the terminal window, so we suggest doing
this _outside_ the VSCode integrated terminal, in case you need to restart
VSCode for an extension update or something.
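If you don't need a shell on the remote host in that session, you can also
pass `-N` so SSH does nothing but forward the ports (same hypothetical host as
above):

```bash
# -N: don't run a remote command; keep the session purely for port forwarding
ssh -N \
    -L 9090:localhost:9090 \
    -L 5173:localhost:5173 \
    user@remote-dev-host
```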
Now, on your remote dev client, you can open `localhost:9090` and access the
UI, now served from the remote dev host, just the same as if it were running
on the client.
```diff
@@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
-        **kwargs,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             noise=noise,
             run_id=run_id,
             callback=callback,
-            **kwargs,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device('cpu')
@@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps,
             conditioning_data,
             noise=noise,
-            additional_guidance=additional_guidance,
             run_id=run_id,
-            callback=callback,
+            additional_guidance=additional_guidance,
             control_data=control_data,
-            **kwargs,
+            callback=callback,
         )
         return result.latents, result.attention_map_saver

@@ -505,42 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id: str = None,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
-
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat([
-                    conditioning_attention_mask,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-                cond = torch.cat([
-                    cond,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-            if encoder_attention_mask is None:
-                encoder_attention_mask = conditioning_attention_mask
-            else:
-                encoder_attention_mask = torch.cat([
-                    encoder_attention_mask,
-                    conditioning_attention_mask,
-                ])
-
-            return cond, encoder_attention_mask
-
-        encoder_attention_mask = None
-        if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
-            max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
-            conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
-            )
-            conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.text_embeddings, max_len, encoder_attention_mask
-            )
-
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -580,8 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 total_step_count=len(timesteps),
                 additional_guidance=additional_guidance,
                 control_data=control_data,
-                encoder_attention_mask=encoder_attention_mask,
-                **kwargs,
             )
             latents = step_output.prev_sample

@@ -623,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         total_step_count: int,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -638,8 +597,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         down_block_res_samples, mid_block_res_sample = None, None

         if control_data is not None:
-            # TODO: rewrite to pass with conditionings
-            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
             # control_data should be type List[ControlNetData]
             # this loop covers both ControlNet (one ControlNetData in list)
             # and MultiControlNet (multiple ControlNetData in list)
@@ -669,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
                     encoder_hidden_states = conditioning_data.text_embeddings
+                    encoder_attention_mask = None
                 else:
-                    encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                       conditioning_data.text_embeddings])
+                    encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+                        conditioning_data.unconditioned_embeddings,
+                        conditioning_data.text_embeddings,
+                    )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
                     controlnet_weight = control_datum.weight[step_index]
```
```diff
@@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
         )
         return latents

+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < max_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+                cond = torch.cat([
+                    cond,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+
     # methods below are called from do_diffusion_step and should be considered private to this class.

     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
@@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

-        both_conditionings = torch.cat([unconditioning, conditioning])
+        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+            unconditioning, conditioning
+        )
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
@@ -260,8 +297,32 @@ class InvokeAIDiffuserComponent:
         **kwargs,
     ):
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
+        unconditioned_next_x = self.model_forward_callback(
+            x, sigma, unconditioning,
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
+            **kwargs,
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x, sigma, conditioning,
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
+            **kwargs,
+        )
         return unconditioned_next_x, conditioned_next_x

     # TODO: looks unused
@@ -295,6 +356,20 @@ class InvokeAIDiffuserComponent:
     ):
         context: Context = self.cross_attention_control_context

+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
         cross_attn_processor_context = SwapCrossAttnContext(
             modified_text_embeddings=context.arguments.edited_conditioning,
             index_map=context.cross_attention_index_map,
@@ -307,6 +382,8 @@ class InvokeAIDiffuserComponent:
             sigma,
             unconditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
             **kwargs,
         )

@@ -319,6 +396,8 @@ class InvokeAIDiffuserComponent:
             sigma,
             conditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
             **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x
```
```diff
@@ -2,9 +2,12 @@ import { ButtonGroup } from '@chakra-ui/react';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import IAIIconButton from 'common/components/IAIIconButton';
 import { memo, useCallback } from 'react';
-import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa';
+import { FaCode, FaExpand, FaMinus, FaPlus, FaInfo } from 'react-icons/fa';
 import { useReactFlow } from 'reactflow';
-import { shouldShowGraphOverlayChanged } from '../store/nodesSlice';
+import {
+  shouldShowGraphOverlayChanged,
+  shouldShowFieldTypeLegendChanged,
+} from '../store/nodesSlice';

 const ViewportControls = () => {
   const { zoomIn, zoomOut, fitView } = useReactFlow();
@@ -12,6 +15,9 @@ const ViewportControls = () => {
   const shouldShowGraphOverlay = useAppSelector(
     (state) => state.nodes.shouldShowGraphOverlay
   );
+  const shouldShowFieldTypeLegend = useAppSelector(
+    (state) => state.nodes.shouldShowFieldTypeLegend
+  );

   const handleClickedZoomIn = useCallback(() => {
     zoomIn();
@@ -29,6 +35,10 @@ const ViewportControls = () => {
     dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
   }, [shouldShowGraphOverlay, dispatch]);

+  const handleClickedToggleFieldTypeLegend = useCallback(() => {
+    dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
+  }, [shouldShowFieldTypeLegend, dispatch]);
+
   return (
     <ButtonGroup isAttached orientation="vertical">
       <IAIIconButton
@@ -52,6 +62,12 @@ const ViewportControls = () => {
         aria-label="Show/Hide Graph"
         icon={<FaCode />}
       />
+      <IAIIconButton
+        isChecked={shouldShowFieldTypeLegend}
+        onClick={handleClickedToggleFieldTypeLegend}
+        aria-label="Show/Hide Field Type Legend"
+        icon={<FaInfo />}
+      />
     </ButtonGroup>
   );
 };
```
```diff
@@ -9,10 +9,13 @@ const TopRightPanel = () => {
   const shouldShowGraphOverlay = useAppSelector(
     (state: RootState) => state.nodes.shouldShowGraphOverlay
   );
+  const shouldShowFieldTypeLegend = useAppSelector(
+    (state: RootState) => state.nodes.shouldShowFieldTypeLegend
+  );

   return (
     <Panel position="top-right">
-      <FieldTypeLegend />
+      {shouldShowFieldTypeLegend && <FieldTypeLegend />}
       {shouldShowGraphOverlay && <NodeGraphOverlay />}
     </Panel>
   );
```
```diff
@@ -32,6 +32,7 @@ export type NodesState = {
   invocationTemplates: Record<string, InvocationTemplate>;
   connectionStartParams: OnConnectStartParams | null;
   shouldShowGraphOverlay: boolean;
+  shouldShowFieldTypeLegend: boolean;
   editorInstance: ReactFlowInstance | undefined;
 };

@@ -42,6 +43,7 @@ export const initialNodesState: NodesState = {
   invocationTemplates: {},
   connectionStartParams: null,
   shouldShowGraphOverlay: false,
+  shouldShowFieldTypeLegend: false,
   editorInstance: undefined,
 };

@@ -125,6 +127,12 @@ const nodesSlice = createSlice({
     shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
       state.shouldShowGraphOverlay = action.payload;
     },
+    shouldShowFieldTypeLegendChanged: (
+      state,
+      action: PayloadAction<boolean>
+    ) => {
+      state.shouldShowFieldTypeLegend = action.payload;
+    },
     nodeTemplatesBuilt: (
       state,
       action: PayloadAction<Record<string, InvocationTemplate>>
@@ -161,6 +169,7 @@ export const {
   connectionStarted,
   connectionEnded,
   shouldShowGraphOverlayChanged,
+  shouldShowFieldTypeLegendChanged,
   nodeTemplatesBuilt,
   nodeEditorReset,
   imageCollectionFieldValueChanged,
```