mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Compare commits: 5 commits, invokeai-b... -> lstein/gen...

Commits: a6efcca78c, 6e0c6d9cc9, a3076cf951, 6696882c71, 17b039e85d
@@ -7,13 +7,15 @@ import mimetypes
 import os
 import shutil
 import traceback
+from pathlib import Path
 from threading import Event
 from uuid import uuid4
 
 import eventlet
-from pathlib import Path
+import invokeai.frontend.dist as frontend
 from PIL import Image
 from PIL.Image import Image as ImageType
+from compel.prompt_parser import Blend
 from flask import Flask, redirect, send_from_directory, request, make_response
 from flask_socketio import SocketIO
 from werkzeug.utils import secure_filename
@@ -22,18 +24,15 @@ from invokeai.backend.modules.get_canvas_generation_mode import (
     get_canvas_generation_mode,
 )
 from invokeai.backend.modules.parameters import parameters_to_command
-import invokeai.frontend.dist as frontend
 from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \
-    get_tokenizer
+from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
 from ldm.invoke.globals import Globals, global_converted_ckpts_dir
-from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from compel.prompt_parser import Blend
 from ldm.invoke.globals import global_models_dir
 from ldm.invoke.merge_diffusers import merge_diffusion_models
+from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
 
 # Loading Arguments
 opt = Args()
@@ -1063,7 +1062,7 @@ class InvokeAIWebServer:
             (width, height) = image.size
             width *= 8
             height *= 8
-            img_base64 = image_to_dataURL(image)
+            img_base64 = image_to_dataURL(image, image_format="JPEG")
             self.socketio.emit(
                 "intermediateResult",
                 {
@@ -1685,27 +1684,23 @@ class CanceledException(Exception):
     pass
 
 
-"""
-Returns a copy an image, cropped to a bounding box.
-"""
-
-
 def copy_image_from_bounding_box(
     image: ImageType, x: int, y: int, width: int, height: int
 ) -> ImageType:
+    """
+    Returns a copy an image, cropped to a bounding box.
+    """
     with image as im:
         bounds = (x, y, x + width, y + height)
         im_cropped = im.crop(bounds)
         return im_cropped
 
 
-"""
-Converts a base64 image dataURL into an image.
-The dataURL is split on the first commma.
-"""
-
-
 def dataURL_to_image(dataURL: str) -> ImageType:
+    """
+    Converts a base64 image dataURL into an image.
+    The dataURL is split on the first comma.
+    """
     image = Image.open(
         io.BytesIO(
             base64.decodebytes(
@@ -1719,27 +1714,24 @@ def dataURL_to_image(dataURL: str) -> ImageType:
     return image
 
 
-"""
-Converts an image into a base64 image dataURL.
-"""
-
-
-def image_to_dataURL(image: ImageType) -> str:
+def image_to_dataURL(image: ImageType, image_format:str="PNG") -> str:
+    """
+    Converts an image into a base64 image dataURL.
+    """
     buffered = io.BytesIO()
-    image.save(buffered, format="PNG")
-    image_base64 = "data:image/png;base64," + base64.b64encode(
+    image.save(buffered, format=image_format)
+    mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
+    image_base64 = f"data:{mime_type};base64," + base64.b64encode(
         buffered.getvalue()
     ).decode("UTF-8")
     return image_base64
 
 
-"""
-Converts a base64 image dataURL into bytes.
-The dataURL is split on the first commma.
-"""
-
-
 def dataURL_to_bytes(dataURL: str) -> bytes:
+    """
+    Converts a base64 image dataURL into bytes.
+    The dataURL is split on the first comma.
+    """
     return base64.decodebytes(
         bytes(
             dataURL.split(",", 1)[1],
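The hunk above threads an `image_format` argument through `image_to_dataURL()` so intermediate results can be pushed over the socket as JPEG while the data URL's MIME prefix stays correct (`PIL.Image.MIME` maps registered format names to MIME types). A minimal standalone sketch of the new behavior, with the helper re-declared here purely for illustration:

import base64
import io

from PIL import Image


def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
    # Serialize in the requested format, then look up the matching MIME type.
    buffered = io.BytesIO()
    image.save(buffered, format=image_format)
    mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
    return f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")


im = Image.new("RGB", (64, 64), "orange")
print(image_to_dataURL(im)[:22])                       # data:image/png;base64,
print(image_to_dataURL(im, image_format="JPEG")[:23])  # data:image/jpeg;base64,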
@@ -1748,11 +1740,6 @@ def dataURL_to_bytes(dataURL: str) -> bytes:
     )
 
 
-"""
-Pastes an image onto another with a bounding box.
-"""
-
-
 def paste_image_into_bounding_box(
     recipient_image: ImageType,
     donor_image: ImageType,
@@ -1761,23 +1748,24 @@ def paste_image_into_bounding_box(
     width: int,
     height: int,
 ) -> ImageType:
+    """
+    Pastes an image onto another with a bounding box.
+    """
     with recipient_image as im:
         bounds = (x, y, x + width, y + height)
         im.paste(donor_image, bounds)
         return recipient_image
 
 
-"""
-Saves a thumbnail of an image, returning its path.
-"""
-
-
 def save_thumbnail(
     image: ImageType,
     filename: str,
     path: str,
     size: int = 256,
 ) -> str:
+    """
+    Saves a thumbnail of an image, returning its path.
+    """
     base_filename = os.path.splitext(filename)[0]
     thumbnail_path = os.path.join(path, base_filename + ".webp")
 
invokeai/renderer1.py (new file, 212 lines)
@@ -0,0 +1,212 @@
'''
Simple class hierarchy
'''
import copy
import dataclasses
import diffusers
import importlib
import traceback

from abc import ABCMeta, abstractmethod
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image
from typing import List, Type
from dataclasses import dataclass
from diffusers.schedulers import SchedulerMixin as Scheduler

import invokeai.assets as image_assets
from ldm.invoke.globals import global_config_dir
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_manager import ModelManager
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.devices import choose_torch_device

@dataclass
class RendererBasicParams:
    width: int=512
    height: int=512
    cfg_scale: int=7.5
    steps: int=20
    ddim_eta: float=0.0
    model: str='stable-diffusion-1.5'
    scheduler: int='ddim'
    precision: str='float16'

@dataclass
class RendererOutput:
    image: Image
    seed: int
    model_name: str
    model_hash: str
    params: RendererBasicParams

class InvokeAIRenderer(metaclass=ABCMeta):
    scheduler_map = dict(
        ddim=diffusers.DDIMScheduler,
        dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_dpm_2=diffusers.KDPM2DiscreteScheduler,
        k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
        k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_euler=diffusers.EulerDiscreteScheduler,
        k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
        k_heun=diffusers.HeunDiscreteScheduler,
        k_lms=diffusers.LMSDiscreteScheduler,
        plms=diffusers.PNDMScheduler,
    )

    def __init__(self,
                 model_manager: ModelManager,
                 params: RendererBasicParams
                 ):
        self.model_manager=model_manager
        self.params=params

    def render(self,
               prompt: str='',
               callback: callable=None,
               iterations: int=1,
               step_callback: callable=None,
               **keyword_args,
               )->List[RendererOutput]:

        results = []
        model_name = self.params.model or self.model_manager.current_model
        model_info: dict = self.model_manager.get_model(model_name)
        model:StableDiffusionGeneratorPipeline = model_info['model']
        model_hash = model_info['hash']
        scheduler: Scheduler = self.get_scheduler(
            model=model,
            scheduler_name=self.params.scheduler
        )
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)

        def _wrap_results(image: Image, seed: int, **kwargs):
            nonlocal results
            output = RendererOutput(
                image=image,
                seed=seed,
                model_name = model_name,
                model_hash = model_hash,
                params=copy.copy(self.params)
            )
            if callback:
                callback(output)
            results.append(output)

        generator = self.load_generator(model, self._generator_name())
        generator.generate(prompt,
                           conditioning=(uc, c, extra_conditioning_info),
                           image_callback=_wrap_results,
                           sampler=scheduler,
                           iterations=iterations,
                           **dataclasses.asdict(self.params),
                           **keyword_args
                           )

        return results

    def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
        module_name = f'ldm.invoke.generator.{class_name.lower()}'
        module = importlib.import_module(module_name)
        constructor = getattr(module, class_name)
        return constructor(model, self.params.precision)

    def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
        scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
        scheduler = scheduler_class.from_config(model.scheduler.config)
        # hack copied over from generate.py
        if not hasattr(scheduler, 'uses_inpainting_model'):
            scheduler.uses_inpainting_model = lambda: False
        return scheduler

    @abstractmethod
    def _generator_name(self)->str:
        '''
        In derived classes will return the name of the generator to use.
        '''
        pass

# ------------------------------------
class Txt2Img(InvokeAIRenderer):
    def _generator_name(self)->str:
        return 'Txt2Img'

# ------------------------------------
class Img2Img(InvokeAIRenderer):
    def render(self,
               init_image: Image,
               strength: float=0.75,
               **keyword_args
               )->List[RendererOutput]:
        return super().render(init_image=init_image,
                              strength=strength,
                              **keyword_args
                              )

    def _generator_name(self)->str:
        return 'Img2Img'

class RendererFactory(object):
    def __init__(self,
                 model_manager: ModelManager,
                 params: RendererBasicParams
                 ):
        self.model_manager = model_manager
        self.params = params

    def renderer(self, rendererclass: Type[InvokeAIRenderer], **keyword_args)->InvokeAIRenderer:
        return rendererclass(self.model_manager,
                             self.params,
                             **keyword_args
                             )

# ---- testing ---
def main():
    config_file = Path(global_config_dir()) / "models.yaml"
    model_manager = ModelManager(OmegaConf.load(config_file),
                                 precision='float16',
                                 device_type=choose_torch_device(),
                                 )

    params = RendererBasicParams(
        model = 'stable-diffusion-1.5',
        steps = 30,
        scheduler = 'k_lms',
        cfg_scale = 8.0,
        height = 640,
        width = 640
    )
    factory = RendererFactory(model_manager, params)

    print ('=== TXT2IMG TEST ===')
    txt2img = factory.renderer(Txt2Img)
    renderer_outputs = txt2img.render(prompt='banana sushi',
                                      iterations=2,
                                      callback=lambda outputs: print(f'SUCCESS: got image with seed {outputs.seed}')
                                      )

    for r in renderer_outputs:
        print(f'image={r.image}, seed={r.seed}, model={r.model_name}, hash={r.model_hash}')


    print ('\n=== IMG2IMG TEST ===')
    img2img = factory.renderer(Img2Img)
    try:
        renderer_outputs = img2img.render(prompt='basket of sushi')
    except Exception as e:
        print(f'SUCCESS: Calling img2img() without required parameter rejected {str(e)}')

    try:
        test_image = Path(__file__,'../../docs/assets/still-life-inpainted.png')
        renderer_outputs = img2img.render(prompt='basket of sushi',
                                          strength=0.5,
                                          init_image=Image.open(test_image))
    except Exception as e:
        print(f'FAILURE: {str(e)}')

    print('Image saved as "ugly-sushi.png"')
    renderer_outputs[0].image.save('ugly-sushi.png')

if __name__=='__main__':
    main()
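renderer1.py above prototypes an eager renderer: `render()` drives the whole batch to completion, with a closure that appends each generated image to a list and fires the per-image callback as it arrives. The following runnable sketch isolates just that accumulation pattern; every `*_stub` name is invented for illustration and is not part of the InvokeAI API:

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class OutputStub:
    image: str
    seed: int


def generate_stub(image_callback: Callable, iterations: int) -> None:
    # Stands in for generator.generate(): invokes the callback once per image.
    for i in range(iterations):
        image_callback(f"image-{i}", 1000 + i)


def render_stub(callback: Optional[Callable] = None, iterations: int = 1) -> List[OutputStub]:
    results: List[OutputStub] = []

    def _wrap_results(image: str, seed: int) -> None:  # closure, as in renderer1
        output = OutputStub(image=image, seed=seed)
        if callback:
            callback(output)
        results.append(output)

    generate_stub(image_callback=_wrap_results, iterations=iterations)
    return results  # eager: every image exists before render() returns


outputs = render_stub(callback=lambda o: print(f"got seed {o.seed}"), iterations=2)
print([o.seed for o in outputs])  # [1000, 1001]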
invokeai/renderer2.py (new file, 187 lines)
@@ -0,0 +1,187 @@
'''
Simple class hierarchy
'''
import copy
import dataclasses
import diffusers
import importlib
import traceback

from abc import ABCMeta, abstractmethod
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image
from typing import List, Type
from dataclasses import dataclass
from diffusers.schedulers import SchedulerMixin as Scheduler

import invokeai.assets as image_assets
from ldm.invoke.globals import global_config_dir
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_manager import ModelManager
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.devices import choose_torch_device

@dataclass
class RendererBasicParams:
    width: int=512
    height: int=512
    cfg_scale: int=7.5
    steps: int=20
    ddim_eta: float=0.0
    model: str='stable-diffusion-1.5'
    scheduler: int='ddim'
    precision: str='float16'

@dataclass
class RendererOutput:
    image: Image
    seed: int
    model_name: str
    model_hash: str
    params: RendererBasicParams

class InvokeAIRenderer(metaclass=ABCMeta):
    scheduler_map = dict(
        ddim=diffusers.DDIMScheduler,
        dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_dpm_2=diffusers.KDPM2DiscreteScheduler,
        k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
        k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_euler=diffusers.EulerDiscreteScheduler,
        k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
        k_heun=diffusers.HeunDiscreteScheduler,
        k_lms=diffusers.LMSDiscreteScheduler,
        plms=diffusers.PNDMScheduler,
    )

    def __init__(self,
                 model_manager: ModelManager,
                 params: RendererBasicParams
                 ):
        self.model_manager=model_manager
        self.params=params

    def render(self,
               prompt: str='',
               callback: callable=None,
               step_callback: callable=None,
               **keyword_args,
               )->List[RendererOutput]:

        model_name = self.params.model or self.model_manager.current_model
        model_info: dict = self.model_manager.get_model(model_name)
        model:StableDiffusionGeneratorPipeline = model_info['model']
        model_hash = model_info['hash']
        scheduler: Scheduler = self.get_scheduler(
            model=model,
            scheduler_name=self.params.scheduler
        )
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)

        def _wrap_results(image: Image, seed: int, **kwargs):
            nonlocal results
            results.append(output)

        generator = self.load_generator(model, self._generator_name())
        while True:
            results = generator.generate(prompt,
                                         conditioning=(uc, c, extra_conditioning_info),
                                         sampler=scheduler,
                                         **dataclasses.asdict(self.params),
                                         **keyword_args
                                         )
            output = RendererOutput(
                image=results[0][0],
                seed=results[0][1],
                model_name = model_name,
                model_hash = model_hash,
                params=copy.copy(self.params)
            )
            if callback:
                callback(output)
            yield output

    def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
        module_name = f'ldm.invoke.generator.{class_name.lower()}'
        module = importlib.import_module(module_name)
        constructor = getattr(module, class_name)
        return constructor(model, self.params.precision)

    def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
        scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
        scheduler = scheduler_class.from_config(model.scheduler.config)
        # hack copied over from generate.py
        if not hasattr(scheduler, 'uses_inpainting_model'):
            scheduler.uses_inpainting_model = lambda: False
        return scheduler

    @abstractmethod
    def _generator_name(self)->str:
        '''
        In derived classes will return the name of the generator to use.
        '''
        pass

# ------------------------------------
class Txt2Img(InvokeAIRenderer):
    def _generator_name(self)->str:
        return 'Txt2Img'

# ------------------------------------
class Img2Img(InvokeAIRenderer):
    def render(self,
               init_image: Image,
               strength: float=0.75,
               **keyword_args
               )->List[RendererOutput]:
        return super().render(init_image=init_image,
                              strength=strength,
                              **keyword_args
                              )

    def _generator_name(self)->str:
        return 'Img2Img'

class RendererFactory(object):
    def __init__(self,
                 model_manager: ModelManager,
                 params: RendererBasicParams
                 ):
        self.model_manager = model_manager
        self.params = params

    def renderer(self, rendererclass: Type[InvokeAIRenderer], **keyword_args)->InvokeAIRenderer:
        return rendererclass(self.model_manager,
                             self.params,
                             **keyword_args
                             )

# ---- testing ---
def main():
    config_file = Path(global_config_dir()) / "models.yaml"
    model_manager = ModelManager(OmegaConf.load(config_file),
                                 precision='float16',
                                 device_type=choose_torch_device(),
                                 )

    params = RendererBasicParams(
        model = 'stable-diffusion-1.5',
        steps = 30,
        scheduler = 'k_lms',
        cfg_scale = 8.0,
        height = 640,
        width = 640
    )
    factory = RendererFactory(model_manager, params)

    print ('=== TXT2IMG TEST ===')
    txt2img = factory.renderer(Txt2Img)
    outputs = txt2img.render(prompt='banana sushi')
    for i in range(3):
        output = next(outputs)
        print(f'image={output.image}, seed={output.seed}, model={output.model_name}, hash={output.model_hash}')


if __name__=='__main__':
    main()
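renderer2.py keeps the same class hierarchy but makes `render()` lazy: a `while True` loop yields one `RendererOutput` per `next()` call, so the caller decides how many images to draw (the return annotation still reads `List[RendererOutput]` even though the body yields). A sketch of the pull-based pattern under the same stub assumptions:

import itertools
from typing import Iterator, Tuple


def render_stub(prompt: str) -> Iterator[Tuple[str, int]]:
    # Stands in for renderer2's `while True: ... yield output` loop.
    for seed in itertools.count(1000):
        yield (f"image for {prompt!r}", seed)


outputs = render_stub("banana sushi")
for _ in range(3):
    image, seed = next(outputs)  # the caller controls how many images to draw
    print(image, seed)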
invokeai/renderer3.py (new file, 191 lines)
@@ -0,0 +1,191 @@
'''
Simple class hierarchy
'''
import copy
import dataclasses
import diffusers
import importlib
import traceback

from abc import ABCMeta, abstractmethod
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image
from typing import List, Type
from dataclasses import dataclass
from diffusers.schedulers import SchedulerMixin as Scheduler

import invokeai.assets as image_assets
from ldm.invoke.globals import global_config_dir
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_manager2 import ModelManager
# ^^^^^^^^^^^^^^ note alternative version
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.devices import choose_torch_device

@dataclass
class RendererBasicParams:
    width: int=512
    height: int=512
    cfg_scale: int=7.5
    steps: int=20
    ddim_eta: float=0.0
    model: str='stable-diffusion-1.5'
    scheduler: int='ddim'
    precision: str='float16'

@dataclass
class RendererOutput:
    image: Image
    seed: int
    model_name: str
    model_hash: str
    params: RendererBasicParams

class InvokeAIRenderer(metaclass=ABCMeta):
    scheduler_map = dict(
        ddim=diffusers.DDIMScheduler,
        dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_dpm_2=diffusers.KDPM2DiscreteScheduler,
        k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
        k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_euler=diffusers.EulerDiscreteScheduler,
        k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
        k_heun=diffusers.HeunDiscreteScheduler,
        k_lms=diffusers.LMSDiscreteScheduler,
        plms=diffusers.PNDMScheduler,
    )

    def __init__(self,
                 model_manager: ModelManager,
                 params: RendererBasicParams
                 ):
        self.model_manager=model_manager
        self.params=params

    def render(self,
               prompt: str='',
               callback: callable=None,
               iterations: int=1,
               step_callback: callable=None,
               **keyword_args,
               )->List[RendererOutput]:
        results = []

        # closure
        def _wrap_results(image: Image, seed: int, **kwargs):
            nonlocal results
            output = RendererOutput(
                image=image,
                seed=seed,
                model_name = model_name,
                model_hash = model_hash,
                params=copy.copy(self.params)
            )
            if callback:
                callback(output)
            results.append(output)

        model_name = self.params.model or self.model_manager.current_model
        print(f'** OUTSIDE CONTEXT: Reference count for {model_name} = {self.model_manager.refcount(model_name)}**')

        with self.model_manager.get_model(model_name) as model_info:
            print(f'** INSIDE CONTEXT: Reference count for {model_name} = {self.model_manager.refcount(model_name)} **')

            model:StableDiffusionGeneratorPipeline = model_info['model']
            model_hash = model_info['hash']
            scheduler: Scheduler = self.get_scheduler(
                model=model,
                scheduler_name=self.params.scheduler
            )
            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)

            generator = self.load_generator(model, self._generator_name())
            generator.generate(prompt,
                               conditioning=(uc, c, extra_conditioning_info),
                               image_callback=_wrap_results,
                               sampler=scheduler,
                               iterations=iterations,
                               **dataclasses.asdict(self.params),
                               **keyword_args
                               )

        print(f'AGAIN OUTSIDE CONTEXT: Reference count for {model_name} = {self.model_manager.refcount(model_name)}')
        return results

    def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
        module_name = f'ldm.invoke.generator.{class_name.lower()}'
        module = importlib.import_module(module_name)
        constructor = getattr(module, class_name)
        return constructor(model, self.params.precision)

    def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
        scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
        scheduler = scheduler_class.from_config(model.scheduler.config)
        # hack copied over from generate.py
        if not hasattr(scheduler, 'uses_inpainting_model'):
            scheduler.uses_inpainting_model = lambda: False
        return scheduler

    @abstractmethod
    def _generator_name(self)->str:
        '''
        In derived classes will return the name of the generator to use.
        '''
        pass

# ------------------------------------
class Txt2Img(InvokeAIRenderer):
    def _generator_name(self)->str:
        return 'Txt2Img'

# ------------------------------------
class Img2Img(InvokeAIRenderer):
    def render(self,
               init_image: Image,
               strength: float=0.75,
               **keyword_args
               )->List[RendererOutput]:
        return super().render(init_image=init_image,
                              strength=strength,
                              **keyword_args
                              )

    def _generator_name(self)->str:
        return 'Img2Img'

class RendererFactory(object):
    def __init__(self,
                 model_manager: ModelManager,
                 params: RendererBasicParams
                 ):
        self.model_manager = model_manager
        self.params = params

    def renderer(self, rendererclass: Type[InvokeAIRenderer], **keyword_args)->InvokeAIRenderer:
        return rendererclass(self.model_manager,
                             self.params,
                             **keyword_args
                             )

# ---- testing ---
def main():
    config_file = Path(global_config_dir()) / "models.yaml"
    model_manager = ModelManager(OmegaConf.load(config_file),
                                 precision='float16',
                                 device_type=choose_torch_device(),
                                 )

    params = RendererBasicParams(
        model = 'stable-diffusion-1.5',
        steps = 30,
        scheduler = 'k_lms',
        cfg_scale = 8.0,
        height = 640,
        width = 640
    )
    factory = RendererFactory(model_manager, params)
    outputs = factory.renderer(Txt2Img).render(prompt='banana sushi')

if __name__=='__main__':
    main()
@@ -62,7 +62,7 @@ class Generator:
         self.variation_amount = variation_amount
         self.with_variations = with_variations
 
-    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
+    def generate(self,prompt,width,height,sampler, init_image=None, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  h_symmetry_time_pct=None, v_symmetry_time_pct=None,
                  safety_checker:dict=None,
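The `Generator.generate()` change demotes `init_image` from a required positional parameter to a keyword argument defaulting to `None`, so txt2img-style callers can omit it; any caller that previously passed `init_image` positionally must switch to the keyword form. An illustrative stub of the new calling convention:

def generate(prompt, width, height, sampler, init_image=None, **kwargs):
    # Stub of the new signature: the mode is inferred from init_image.
    mode = "img2img" if init_image is not None else "txt2img"
    return f"{mode}: {prompt!r} at {width}x{height} via {sampler}"


print(generate("banana sushi", 512, 512, "k_lms"))                          # txt2img
print(generate("banana sushi", 512, 512, "k_lms", init_image="sushi.png"))  # img2img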
@@ -55,7 +55,6 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
     "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
 }
 
-
 class ModelManager(object):
     def __init__(
         self,
@@ -79,10 +78,12 @@ class ModelManager(object):
         self.device = torch.device(device_type)
         self.max_loaded_models = max_loaded_models
         self.models = {}
+        self.in_use = {}  # ref counts of models in use, for locking some day
         self.stack = []  # this is an LRU FIFO
         self.current_model = None
         self.sequential_offload = sequential_offload
 
+
     def valid_model(self, model_name: str) -> bool:
         """
         Given a model name, returns True if it is a valid
ldm/invoke/model_manager2.py (new file, 1399 lines)
File diff suppressed because it is too large.