''' Simple class hierarchy ''' import copy import dataclasses import diffusers import importlib import traceback from abc import ABCMeta, abstractmethod from omegaconf import OmegaConf from pathlib import Path from PIL import Image from typing import List, Type from dataclasses import dataclass from diffusers.schedulers import SchedulerMixin as Scheduler import invokeai.assets as image_assets from ldm.invoke.globals import global_config_dir from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_manager2 import ModelManager # ^^^^^^^^^^^^^^ note alternative version from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline from ldm.invoke.devices import choose_torch_device @dataclass class RendererBasicParams: width: int=512 height: int=512 cfg_scale: int=7.5 steps: int=20 ddim_eta: float=0.0 model: str='stable-diffusion-1.5' scheduler: int='ddim' precision: str='float16' @dataclass class RendererOutput: image: Image seed: int model_name: str model_hash: str params: RendererBasicParams class InvokeAIRenderer(metaclass=ABCMeta): scheduler_map = dict( ddim=diffusers.DDIMScheduler, dpmpp_2=diffusers.DPMSolverMultistepScheduler, k_dpm_2=diffusers.KDPM2DiscreteScheduler, k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler, k_dpmpp_2=diffusers.DPMSolverMultistepScheduler, k_euler=diffusers.EulerDiscreteScheduler, k_euler_a=diffusers.EulerAncestralDiscreteScheduler, k_heun=diffusers.HeunDiscreteScheduler, k_lms=diffusers.LMSDiscreteScheduler, plms=diffusers.PNDMScheduler, ) def __init__(self, model_manager: ModelManager, params: RendererBasicParams ): self.model_manager=model_manager self.params=params def render(self, prompt: str='', callback: callable=None, iterations: int=1, step_callback: callable=None, **keyword_args, )->List[RendererOutput]: results = [] # closure def _wrap_results(image: Image, seed: int, **kwargs): nonlocal results output = RendererOutput( image=image, seed=seed, model_name = model_name, model_hash = model_hash, params=copy.copy(self.params) ) if callback: callback(output) results.append(output) model_name = self.params.model or self.model_manager.current_model print(f'** OUTSIDE CONTEXT: Reference count for {model_name} = {self.model_manager.refcount(model_name)}**') with self.model_manager.get_model(model_name) as model_info: print(f'** INSIDE CONTEXT: Reference count for {model_name} = {self.model_manager.refcount(model_name)} **') model:StableDiffusionGeneratorPipeline = model_info['model'] model_hash = model_info['hash'] scheduler: Scheduler = self.get_scheduler( model=model, scheduler_name=self.params.scheduler ) uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) generator = self.load_generator(model, self._generator_name()) generator.generate(prompt, conditioning=(uc, c, extra_conditioning_info), image_callback=_wrap_results, sampler=scheduler, iterations=iterations, **dataclasses.asdict(self.params), **keyword_args ) print(f'AGAIN OUTSIDE CONTEXT: Reference count for {model_name} = {self.model_manager.refcount(model_name)}') return results def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str): module_name = f'ldm.invoke.generator.{class_name.lower()}' module = importlib.import_module(module_name) constructor = getattr(module, class_name) return constructor(model, self.params.precision) def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: scheduler_class = self.scheduler_map.get(scheduler_name,'ddim') scheduler = scheduler_class.from_config(model.scheduler.config) # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False return scheduler @abstractmethod def _generator_name(self)->str: ''' In derived classes will return the name of the generator to use. ''' pass # ------------------------------------ class Txt2Img(InvokeAIRenderer): def _generator_name(self)->str: return 'Txt2Img' # ------------------------------------ class Img2Img(InvokeAIRenderer): def render(self, init_image: Image, strength: float=0.75, **keyword_args )->List[RendererOutput]: return super().render(init_image=init_image, strength=strength, **keyword_args ) def _generator_name(self)->str: return 'Img2Img' class RendererFactory(object): def __init__(self, model_manager: ModelManager, params: RendererBasicParams ): self.model_manager = model_manager self.params = params def renderer(self, rendererclass: Type[InvokeAIRenderer], **keyword_args)->InvokeAIRenderer: return rendererclass(self.model_manager, self.params, **keyword_args ) # ---- testing --- def main(): config_file = Path(global_config_dir()) / "models.yaml" model_manager = ModelManager(OmegaConf.load(config_file), precision='float16', device_type=choose_torch_device(), ) params = RendererBasicParams( model = 'stable-diffusion-1.5', steps = 30, scheduler = 'k_lms', cfg_scale = 8.0, height = 640, width = 640 ) factory = RendererFactory(model_manager, params) outputs = factory.renderer(Txt2Img).render(prompt='banana sushi') if __name__=='__main__': main()