mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'feat_longer_prompts' of github.com:damian0815/InvokeAI into feat_longer_prompts
This commit is contained in:
commit
8076c1697c
@ -102,6 +102,29 @@ def generate_matching_edges(
|
|||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
class SessionError(Exception):
|
||||||
|
"""Raised when a session error has occurred"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_all(context: CliContext):
|
||||||
|
"""Runs all invocations in the specified session"""
|
||||||
|
context.invoker.invoke(context.session, invoke_all=True)
|
||||||
|
while not context.session.is_complete():
|
||||||
|
# Wait some time
|
||||||
|
session = context.get_session()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Print any errors
|
||||||
|
if context.session.has_error():
|
||||||
|
for n in context.session.errors:
|
||||||
|
print(
|
||||||
|
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise SessionError()
|
||||||
|
|
||||||
|
|
||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
args = Args()
|
args = Args()
|
||||||
config = args.parse_args()
|
config = args.parse_args()
|
||||||
@ -134,7 +157,6 @@ def invoke_cli():
|
|||||||
|
|
||||||
invoker = Invoker(services)
|
invoker = Invoker(services)
|
||||||
session: GraphExecutionState = invoker.create_execution_state()
|
session: GraphExecutionState = invoker.create_execution_state()
|
||||||
|
|
||||||
parser = get_command_parser()
|
parser = get_command_parser()
|
||||||
|
|
||||||
# Uncomment to print out previous sessions at startup
|
# Uncomment to print out previous sessions at startup
|
||||||
@ -151,8 +173,7 @@ def invoke_cli():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Refresh the state of the session
|
# Refresh the state of the session
|
||||||
session = invoker.services.graph_execution_manager.get(session.id)
|
history = list(get_graph_execution_history(context.session))
|
||||||
history = list(get_graph_execution_history(session))
|
|
||||||
|
|
||||||
# Split the command for piping
|
# Split the command for piping
|
||||||
cmds = cmd_input.split("|")
|
cmds = cmd_input.split("|")
|
||||||
@ -164,7 +185,7 @@ def invoke_cli():
|
|||||||
raise InvalidArgs("Empty command")
|
raise InvalidArgs("Empty command")
|
||||||
|
|
||||||
# Parse args to create invocation
|
# Parse args to create invocation
|
||||||
args = vars(parser.parse_args(shlex.split(cmd.strip())))
|
args = vars(context.parser.parse_args(shlex.split(cmd.strip())))
|
||||||
|
|
||||||
# Override defaults
|
# Override defaults
|
||||||
for field_name, field_default in context.defaults.items():
|
for field_name, field_default in context.defaults.items():
|
||||||
@ -176,11 +197,11 @@ def invoke_cli():
|
|||||||
command = CliCommand(command=args)
|
command = CliCommand(command=args)
|
||||||
|
|
||||||
# Run any CLI commands immediately
|
# Run any CLI commands immediately
|
||||||
# TODO: this won't behave as expected if piping and using e.g. history,
|
|
||||||
# since invocations are gathered and then run together at the end.
|
|
||||||
# This is more efficient if the CLI is running against a distributed
|
|
||||||
# backend, so it's preferable not to change that behavior.
|
|
||||||
if isinstance(command.command, BaseCommand):
|
if isinstance(command.command, BaseCommand):
|
||||||
|
# Invoke all current nodes to preserve operation order
|
||||||
|
invoke_all(context)
|
||||||
|
|
||||||
|
# Run the command
|
||||||
command.command.run(context)
|
command.command.run(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -193,7 +214,7 @@ def invoke_cli():
|
|||||||
from_node = (
|
from_node = (
|
||||||
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
||||||
if current_id != start_id
|
if current_id != start_id
|
||||||
else session.graph.get_node(from_id)
|
else context.session.graph.get_node(from_id)
|
||||||
)
|
)
|
||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
from_node, command.command
|
from_node, command.command
|
||||||
@ -203,7 +224,7 @@ def invoke_cli():
|
|||||||
# Parse provided links
|
# Parse provided links
|
||||||
if "link_node" in args and args["link_node"]:
|
if "link_node" in args and args["link_node"]:
|
||||||
for link in args["link_node"]:
|
for link in args["link_node"]:
|
||||||
link_node = session.graph.get_node(link)
|
link_node = context.session.graph.get_node(link)
|
||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
link_node, command.command
|
link_node, command.command
|
||||||
)
|
)
|
||||||
@ -227,37 +248,24 @@ def invoke_cli():
|
|||||||
|
|
||||||
current_id = current_id + 1
|
current_id = current_id + 1
|
||||||
|
|
||||||
# Command line was parsed successfully
|
# Add the node to the session
|
||||||
# Add the invocations to the session
|
context.session.add_node(command.command)
|
||||||
for invocation in new_invocations:
|
for edge in edges:
|
||||||
session.add_node(invocation[0])
|
|
||||||
for edge in invocation[1]:
|
|
||||||
print(edge)
|
print(edge)
|
||||||
session.add_edge(edge)
|
context.session.add_edge(edge)
|
||||||
|
|
||||||
# Execute all available invocations
|
# Execute all remaining nodes
|
||||||
invoker.invoke(session, invoke_all=True)
|
invoke_all(context)
|
||||||
while not session.is_complete():
|
|
||||||
# Wait some time
|
|
||||||
session = context.get_session()
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
# Print any errors
|
|
||||||
if session.has_error():
|
|
||||||
for n in session.errors:
|
|
||||||
print(
|
|
||||||
f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start a new session
|
|
||||||
print("Creating a new session")
|
|
||||||
session = invoker.create_execution_state()
|
|
||||||
context.session = session
|
|
||||||
|
|
||||||
except InvalidArgs:
|
except InvalidArgs:
|
||||||
print('Invalid command, use "help" to list commands')
|
print('Invalid command, use "help" to list commands')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
except SessionError:
|
||||||
|
# Start a new session
|
||||||
|
print("Session error: creating a new session")
|
||||||
|
context.session = context.invoker.create_execution_state()
|
||||||
|
|
||||||
except ExitCli:
|
except ExitCli:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -99,6 +99,7 @@ class Generator:
|
|||||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||||
attention_maps_callback=attention_maps_callback,
|
attention_maps_callback=attention_maps_callback,
|
||||||
|
seed=seed,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
@ -289,9 +290,7 @@ class Generator:
|
|||||||
if self.variation_amount > 0:
|
if self.variation_amount > 0:
|
||||||
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
||||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||||
return (seed, initial_noise)
|
return (seed, initial_noise)
|
||||||
else:
|
|
||||||
return (seed, None)
|
|
||||||
|
|
||||||
# returns a tensor filled with random numbers from a normal distribution
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
def get_noise(self, width, height):
|
def get_noise(self, width, height):
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
invokeai.backend.generator.img2img descends from .generator
|
invokeai.backend.generator.img2img descends from .generator
|
||||||
"""
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate.utils import set_seed
|
||||||
from diffusers import logging
|
from diffusers import logging
|
||||||
|
|
||||||
from ..stable_diffusion import (
|
from ..stable_diffusion import (
|
||||||
@ -35,6 +37,7 @@ class Img2Img(Generator):
|
|||||||
h_symmetry_time_pct=None,
|
h_symmetry_time_pct=None,
|
||||||
v_symmetry_time_pct=None,
|
v_symmetry_time_pct=None,
|
||||||
attention_maps_callback=None,
|
attention_maps_callback=None,
|
||||||
|
seed=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -65,6 +68,7 @@ class Img2Img(Generator):
|
|||||||
# FIXME: use x_T for initial seeded noise
|
# FIXME: use x_T for initial seeded noise
|
||||||
# We're not at the moment because the pipeline automatically resizes init_image if
|
# We're not at the moment because the pipeline automatically resizes init_image if
|
||||||
# necessary, which the x_T input might not match.
|
# necessary, which the x_T input might not match.
|
||||||
|
# In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
|
||||||
logging.set_verbosity_error() # quench safety check warnings
|
logging.set_verbosity_error() # quench safety check warnings
|
||||||
pipeline_output = pipeline.img2img_from_embeddings(
|
pipeline_output = pipeline.img2img_from_embeddings(
|
||||||
init_image,
|
init_image,
|
||||||
@ -73,6 +77,7 @@ class Img2Img(Generator):
|
|||||||
conditioning_data,
|
conditioning_data,
|
||||||
noise_func=self.get_noise_like,
|
noise_func=self.get_noise_like,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
|
seed=seed
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
pipeline_output.attention_map_saver is not None
|
pipeline_output.attention_map_saver is not None
|
||||||
@ -83,7 +88,9 @@ class Img2Img(Generator):
|
|||||||
|
|
||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
def get_noise_like(self, like: torch.Tensor):
|
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
|
||||||
|
if seed is not None:
|
||||||
|
set_seed(seed)
|
||||||
device = like.device
|
device = like.device
|
||||||
if device.type == "mps":
|
if device.type == "mps":
|
||||||
x = torch.randn_like(like, device="cpu").to(device)
|
x = torch.randn_like(like, device="cpu").to(device)
|
||||||
|
@ -223,6 +223,7 @@ class Inpaint(Img2Img):
|
|||||||
inpaint_height=None,
|
inpaint_height=None,
|
||||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||||
attention_maps_callback=None,
|
attention_maps_callback=None,
|
||||||
|
seed=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -319,6 +320,7 @@ class Inpaint(Img2Img):
|
|||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
noise_func=self.get_noise_like,
|
noise_func=self.get_noise_like,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
|
seed=seed
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -690,6 +690,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
noise_func=None,
|
noise_func=None,
|
||||||
|
seed=None,
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
if isinstance(init_image, PIL.Image.Image):
|
if isinstance(init_image, PIL.Image.Image):
|
||||||
init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
|
init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
|
||||||
@ -703,7 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
device=self._model_group.device_for(self.unet),
|
device=self._model_group.device_for(self.unet),
|
||||||
dtype=self.unet.dtype,
|
dtype=self.unet.dtype,
|
||||||
)
|
)
|
||||||
noise = noise_func(initial_latents)
|
noise = noise_func(initial_latents, seed)
|
||||||
|
|
||||||
return self.img2img_from_latents_and_embeddings(
|
return self.img2img_from_latents_and_embeddings(
|
||||||
initial_latents,
|
initial_latents,
|
||||||
@ -731,9 +732,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
device=self._model_group.device_for(self.unet),
|
device=self._model_group.device_for(self.unet),
|
||||||
)
|
)
|
||||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||||
initial_latents,
|
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
||||||
num_inference_steps,
|
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
||||||
conditioning_data,
|
),
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
@ -779,6 +782,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
noise_func=None,
|
noise_func=None,
|
||||||
|
seed=None,
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
device = self._model_group.device_for(self.unet)
|
device = self._model_group.device_for(self.unet)
|
||||||
latents_dtype = self.unet.dtype
|
latents_dtype = self.unet.dtype
|
||||||
@ -802,7 +806,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
init_image_latents = self.non_noised_latents_from_image(
|
init_image_latents = self.non_noised_latents_from_image(
|
||||||
init_image, device=device, dtype=latents_dtype
|
init_image, device=device, dtype=latents_dtype
|
||||||
)
|
)
|
||||||
noise = noise_func(init_image_latents)
|
noise = noise_func(init_image_latents, seed)
|
||||||
|
|
||||||
if mask.dim() == 3:
|
if mask.dim() == 3:
|
||||||
mask = mask.unsqueeze(0)
|
mask = mask.unsqueeze(0)
|
||||||
@ -831,9 +835,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||||
init_image_latents,
|
latents=init_image_latents if strength < 1.0 else torch.zeros_like(
|
||||||
num_inference_steps,
|
init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
|
||||||
conditioning_data,
|
),
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
additional_guidance=guidance,
|
additional_guidance=guidance,
|
||||||
|
38
invokeai/frontend/web/src/component.tsx
Normal file
38
invokeai/frontend/web/src/component.tsx
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import React, { lazy } from 'react';
|
||||||
|
import { Provider } from 'react-redux';
|
||||||
|
import { PersistGate } from 'redux-persist/integration/react';
|
||||||
|
import { store } from './app/store';
|
||||||
|
import { persistor } from './persistor';
|
||||||
|
import '@fontsource/inter/100.css';
|
||||||
|
import '@fontsource/inter/200.css';
|
||||||
|
import '@fontsource/inter/300.css';
|
||||||
|
import '@fontsource/inter/400.css';
|
||||||
|
import '@fontsource/inter/500.css';
|
||||||
|
import '@fontsource/inter/600.css';
|
||||||
|
import '@fontsource/inter/700.css';
|
||||||
|
import '@fontsource/inter/800.css';
|
||||||
|
import '@fontsource/inter/900.css';
|
||||||
|
|
||||||
|
import Loading from './Loading';
|
||||||
|
|
||||||
|
// Localization
|
||||||
|
import './i18n';
|
||||||
|
|
||||||
|
const App = lazy(() => import('./app/App'));
|
||||||
|
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||||
|
|
||||||
|
export default function Component() {
|
||||||
|
return (
|
||||||
|
<React.StrictMode>
|
||||||
|
<Provider store={store}>
|
||||||
|
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||||
|
<React.Suspense fallback={<Loading showText />}>
|
||||||
|
<ThemeLocaleProvider>
|
||||||
|
<App />
|
||||||
|
</ThemeLocaleProvider>
|
||||||
|
</React.Suspense>
|
||||||
|
</PersistGate>
|
||||||
|
</Provider>
|
||||||
|
</React.StrictMode>
|
||||||
|
);
|
||||||
|
}
|
@ -1,37 +1,7 @@
|
|||||||
import React, { lazy } from 'react';
|
|
||||||
import ReactDOM from 'react-dom/client';
|
import ReactDOM from 'react-dom/client';
|
||||||
import { Provider } from 'react-redux';
|
|
||||||
import { PersistGate } from 'redux-persist/integration/react';
|
|
||||||
import { store } from './app/store';
|
|
||||||
import { persistor } from './persistor';
|
|
||||||
import '@fontsource/inter/100.css';
|
|
||||||
import '@fontsource/inter/200.css';
|
|
||||||
import '@fontsource/inter/300.css';
|
|
||||||
import '@fontsource/inter/400.css';
|
|
||||||
import '@fontsource/inter/500.css';
|
|
||||||
import '@fontsource/inter/600.css';
|
|
||||||
import '@fontsource/inter/700.css';
|
|
||||||
import '@fontsource/inter/800.css';
|
|
||||||
import '@fontsource/inter/900.css';
|
|
||||||
|
|
||||||
import Loading from './Loading';
|
import Component from './component';
|
||||||
|
|
||||||
// Localization
|
|
||||||
import './i18n';
|
|
||||||
|
|
||||||
const App = lazy(() => import('./app/App'));
|
|
||||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
|
||||||
|
|
||||||
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
|
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
|
||||||
<React.StrictMode>
|
<Component />
|
||||||
<Provider store={store}>
|
|
||||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
|
||||||
<React.Suspense fallback={<Loading showText />}>
|
|
||||||
<ThemeLocaleProvider>
|
|
||||||
<App />
|
|
||||||
</ThemeLocaleProvider>
|
|
||||||
</React.Suspense>
|
|
||||||
</PersistGate>
|
|
||||||
</Provider>
|
|
||||||
</React.StrictMode>
|
|
||||||
);
|
);
|
||||||
|
@ -63,6 +63,7 @@ dependencies = [
|
|||||||
"prompt-toolkit",
|
"prompt-toolkit",
|
||||||
"pypatchmatch",
|
"pypatchmatch",
|
||||||
"pyreadline3",
|
"pyreadline3",
|
||||||
|
"pytorch-lightning==1.7.7",
|
||||||
"realesrgan",
|
"realesrgan",
|
||||||
"requests==2.28.2",
|
"requests==2.28.2",
|
||||||
"rich~=13.3",
|
"rich~=13.3",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user