diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py
index 721760b222..9dc1429d92 100644
--- a/invokeai/app/cli_app.py
+++ b/invokeai/app/cli_app.py
@@ -102,6 +102,29 @@ def generate_matching_edges(
return edges
+class SessionError(Exception):
+ """Raised when a session error has occurred"""
+ pass
+
+
+def invoke_all(context: CliContext):
+ """Runs all invocations in the specified session"""
+ context.invoker.invoke(context.session, invoke_all=True)
+ while not context.session.is_complete():
+ # Wait some time
+ session = context.get_session()
+ time.sleep(0.1)
+
+ # Print any errors
+ if context.session.has_error():
+ for n in context.session.errors:
+ print(
+ f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
+ )
+
+ raise SessionError()
+
+
def invoke_cli():
args = Args()
config = args.parse_args()
@@ -134,7 +157,6 @@ def invoke_cli():
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()
-
parser = get_command_parser()
# Uncomment to print out previous sessions at startup
@@ -151,8 +173,7 @@ def invoke_cli():
try:
# Refresh the state of the session
- session = invoker.services.graph_execution_manager.get(session.id)
- history = list(get_graph_execution_history(session))
+ history = list(get_graph_execution_history(context.session))
# Split the command for piping
cmds = cmd_input.split("|")
@@ -164,7 +185,7 @@ def invoke_cli():
raise InvalidArgs("Empty command")
# Parse args to create invocation
- args = vars(parser.parse_args(shlex.split(cmd.strip())))
+ args = vars(context.parser.parse_args(shlex.split(cmd.strip())))
# Override defaults
for field_name, field_default in context.defaults.items():
@@ -176,11 +197,11 @@ def invoke_cli():
command = CliCommand(command=args)
# Run any CLI commands immediately
- # TODO: this won't behave as expected if piping and using e.g. history,
- # since invocations are gathered and then run together at the end.
- # This is more efficient if the CLI is running against a distributed
- # backend, so it's preferable not to change that behavior.
if isinstance(command.command, BaseCommand):
+ # Invoke all current nodes to preserve operation order
+ invoke_all(context)
+
+ # Run the command
command.command.run(context)
continue
@@ -193,7 +214,7 @@ def invoke_cli():
from_node = (
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
if current_id != start_id
- else session.graph.get_node(from_id)
+ else context.session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.command
@@ -203,7 +224,7 @@ def invoke_cli():
# Parse provided links
if "link_node" in args and args["link_node"]:
for link in args["link_node"]:
- link_node = session.graph.get_node(link)
+ link_node = context.session.graph.get_node(link)
matching_edges = generate_matching_edges(
link_node, command.command
)
@@ -227,37 +248,24 @@ def invoke_cli():
current_id = current_id + 1
- # Command line was parsed successfully
- # Add the invocations to the session
- for invocation in new_invocations:
- session.add_node(invocation[0])
- for edge in invocation[1]:
+ # Add the node to the session
+ context.session.add_node(command.command)
+ for edge in edges:
print(edge)
- session.add_edge(edge)
+ context.session.add_edge(edge)
- # Execute all available invocations
- invoker.invoke(session, invoke_all=True)
- while not session.is_complete():
- # Wait some time
- session = context.get_session()
- time.sleep(0.1)
-
- # Print any errors
- if session.has_error():
- for n in session.errors:
- print(
- f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}"
- )
-
- # Start a new session
- print("Creating a new session")
- session = invoker.create_execution_state()
- context.session = session
+ # Execute all remaining nodes
+ invoke_all(context)
except InvalidArgs:
print('Invalid command, use "help" to list commands')
continue
+ except SessionError:
+ # Start a new session
+ print("Session error: creating a new session")
+ context.session = context.invoker.create_execution_state()
+
except ExitCli:
break
diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py
index a834e9dba3..881d3deaff 100644
--- a/invokeai/backend/generator/base.py
+++ b/invokeai/backend/generator/base.py
@@ -99,6 +99,7 @@ class Generator:
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
attention_maps_callback=attention_maps_callback,
+ seed=seed,
**kwargs,
)
results = []
@@ -289,9 +290,7 @@ class Generator:
if self.variation_amount > 0:
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
seed = random.randrange(0, np.iinfo(np.uint32).max)
- return (seed, initial_noise)
- else:
- return (seed, None)
+ return (seed, initial_noise)
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self, width, height):
diff --git a/invokeai/backend/generator/img2img.py b/invokeai/backend/generator/img2img.py
index f9692b9cc5..b7f90361a0 100644
--- a/invokeai/backend/generator/img2img.py
+++ b/invokeai/backend/generator/img2img.py
@@ -1,8 +1,10 @@
"""
invokeai.backend.generator.img2img descends from .generator
"""
+from typing import Optional
import torch
+from accelerate.utils import set_seed
from diffusers import logging
from ..stable_diffusion import (
@@ -35,6 +37,7 @@ class Img2Img(Generator):
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
+ seed=None,
**kwargs,
):
"""
@@ -65,6 +68,7 @@ class Img2Img(Generator):
# FIXME: use x_T for initial seeded noise
# We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match.
+ # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
logging.set_verbosity_error() # quench safety check warnings
pipeline_output = pipeline.img2img_from_embeddings(
init_image,
@@ -73,6 +77,7 @@ class Img2Img(Generator):
conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
+ seed=seed
)
if (
pipeline_output.attention_map_saver is not None
@@ -83,7 +88,9 @@ class Img2Img(Generator):
return make_image
- def get_noise_like(self, like: torch.Tensor):
+ def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
+ if seed is not None:
+ set_seed(seed)
device = like.device
if device.type == "mps":
x = torch.randn_like(like, device="cpu").to(device)
diff --git a/invokeai/backend/generator/inpaint.py b/invokeai/backend/generator/inpaint.py
index f7f21b8906..fa7560d43e 100644
--- a/invokeai/backend/generator/inpaint.py
+++ b/invokeai/backend/generator/inpaint.py
@@ -223,6 +223,7 @@ class Inpaint(Img2Img):
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
+ seed=None,
**kwargs,
):
"""
@@ -319,6 +320,7 @@ class Inpaint(Img2Img):
conditioning_data=conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
+ seed=seed
)
if (
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index f1622f86be..6bd1fe339d 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -690,6 +690,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
noise_func=None,
+ seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
@@ -703,7 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype,
)
- noise = noise_func(initial_latents)
+ noise = noise_func(initial_latents, seed)
return self.img2img_from_latents_and_embeddings(
initial_latents,
@@ -731,9 +732,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device=self._model_group.device_for(self.unet),
)
result_latents, result_attention_maps = self.latents_from_embeddings(
- initial_latents,
- num_inference_steps,
- conditioning_data,
+ latents=initial_latents if strength < 1.0 else torch.zeros_like(
+ initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
+ ),
+ num_inference_steps=num_inference_steps,
+ conditioning_data=conditioning_data,
timesteps=timesteps,
noise=noise,
run_id=run_id,
@@ -779,6 +782,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
noise_func=None,
+ seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self._model_group.device_for(self.unet)
latents_dtype = self.unet.dtype
@@ -802,7 +806,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image_latents = self.non_noised_latents_from_image(
init_image, device=device, dtype=latents_dtype
)
- noise = noise_func(init_image_latents)
+ noise = noise_func(init_image_latents, seed)
if mask.dim() == 3:
mask = mask.unsqueeze(0)
@@ -831,9 +835,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
try:
result_latents, result_attention_maps = self.latents_from_embeddings(
- init_image_latents,
- num_inference_steps,
- conditioning_data,
+ latents=init_image_latents if strength < 1.0 else torch.zeros_like(
+ init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
+ ),
+ num_inference_steps=num_inference_steps,
+ conditioning_data=conditioning_data,
noise=noise,
timesteps=timesteps,
additional_guidance=guidance,
diff --git a/invokeai/frontend/web/src/component.tsx b/invokeai/frontend/web/src/component.tsx
new file mode 100644
index 0000000000..8bd85b50e1
--- /dev/null
+++ b/invokeai/frontend/web/src/component.tsx
@@ -0,0 +1,38 @@
+import React, { lazy } from 'react';
+import { Provider } from 'react-redux';
+import { PersistGate } from 'redux-persist/integration/react';
+import { store } from './app/store';
+import { persistor } from './persistor';
+import '@fontsource/inter/100.css';
+import '@fontsource/inter/200.css';
+import '@fontsource/inter/300.css';
+import '@fontsource/inter/400.css';
+import '@fontsource/inter/500.css';
+import '@fontsource/inter/600.css';
+import '@fontsource/inter/700.css';
+import '@fontsource/inter/800.css';
+import '@fontsource/inter/900.css';
+
+import Loading from './Loading';
+
+// Localization
+import './i18n';
+
+const App = lazy(() => import('./app/App'));
+const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
+
+export default function Component() {
+ return (
+
+
+ } persistor={persistor}>
+ }>
+
+
+
+
+
+
+
+ );
+}
diff --git a/invokeai/frontend/web/src/main.tsx b/invokeai/frontend/web/src/main.tsx
index edbca0c368..e0423d7e4e 100644
--- a/invokeai/frontend/web/src/main.tsx
+++ b/invokeai/frontend/web/src/main.tsx
@@ -1,37 +1,7 @@
-import React, { lazy } from 'react';
import ReactDOM from 'react-dom/client';
-import { Provider } from 'react-redux';
-import { PersistGate } from 'redux-persist/integration/react';
-import { store } from './app/store';
-import { persistor } from './persistor';
-import '@fontsource/inter/100.css';
-import '@fontsource/inter/200.css';
-import '@fontsource/inter/300.css';
-import '@fontsource/inter/400.css';
-import '@fontsource/inter/500.css';
-import '@fontsource/inter/600.css';
-import '@fontsource/inter/700.css';
-import '@fontsource/inter/800.css';
-import '@fontsource/inter/900.css';
-import Loading from './Loading';
-
-// Localization
-import './i18n';
-
-const App = lazy(() => import('./app/App'));
-const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
+import Component from './component';
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
-
-
- } persistor={persistor}>
- }>
-
-
-
-
-
-
-
+
);
diff --git a/pyproject.toml b/pyproject.toml
index db29c53883..9534d0ce07 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,6 +63,7 @@ dependencies = [
"prompt-toolkit",
"pypatchmatch",
"pyreadline3",
+ "pytorch-lightning==1.7.7",
"realesrgan",
"requests==2.28.2",
"rich~=13.3",