mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
remove factory pattern
Factory pattern is now removed. Typical usage of the InvokeAIGenerator is now: ``` from invokeai.backend.generator import ( InvokeAIGeneratorBasicParams, Txt2Img, Img2Img, Inpaint, ) params = InvokeAIGeneratorBasicParams( model_name = 'stable-diffusion-1.5', steps = 30, scheduler = 'k_lms', cfg_scale = 8.0, height = 640, width = 640 ) print ('=== TXT2IMG TEST ===') txt2img = Txt2Img(manager, params) outputs = txt2img.generate(prompt='banana sushi', iterations=2) for i in outputs: print(f'image={output.image}, seed={output.seed}, model={output.params.model_name}, hash={output.model_hash}, steps={output.params.steps}') ``` The `params` argument is optional, so if you wish to accept default parameters and selectively override them, just do this: ``` outputs = Txt2Img(manager).generate(prompt='banana sushi', steps=50, scheduler='k_heun', model_name='stable-diffusion-2.1' ) ```
This commit is contained in:
parent
c11e823ff3
commit
95954188b2
@ -4,7 +4,7 @@ import os
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
from ..services.generate_initializer import get_generator_factory
|
from ..services.generate_initializer import get_model_manager
|
||||||
from ..services.graph import GraphExecutionState
|
from ..services.graph import GraphExecutionState
|
||||||
from ..services.image_storage import DiskImageStorage
|
from ..services.image_storage import DiskImageStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -47,7 +47,7 @@ class ApiDependencies:
|
|||||||
# TODO: Use a logger
|
# TODO: Use a logger
|
||||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||||
|
|
||||||
generator_factory = get_generator_factory(args, config)
|
model_manager = get_model_manager(args, config)
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
|
|||||||
from .invocations import *
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.generate_initializer import get_generator_factory
|
from .services.generate_initializer import get_model_manager
|
||||||
from .services.graph import EdgeConnection, GraphExecutionState
|
from .services.graph import EdgeConnection, GraphExecutionState
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -129,7 +129,7 @@ def invoke_cli():
|
|||||||
args = Args()
|
args = Args()
|
||||||
config = args.parse_args()
|
config = args.parse_args()
|
||||||
|
|
||||||
generator_factory = get_generator_factory(args, config)
|
model_manager = get_model_manager(args, config)
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ def invoke_cli():
|
|||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
generator_factory=generator_factory,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
images=DiskImageStorage(output_folder),
|
images=DiskImageStorage(output_folder),
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
@ -18,7 +18,6 @@ SAMPLER_NAME_VALUES = Literal[
|
|||||||
tuple(InvokeAIGenerator.schedulers())
|
tuple(InvokeAIGenerator.schedulers())
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class TextToImageInvocation(BaseInvocation):
|
class TextToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image using text2img."""
|
"""Generates an image using text2img."""
|
||||||
@ -58,15 +57,8 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
factory = context.services.generator_factory
|
manager = context.services.model_manager
|
||||||
if self.model:
|
outputs = Txt2Img(manager).generate(
|
||||||
factory.model_name = self.model
|
|
||||||
else:
|
|
||||||
self.model = factory.model_name
|
|
||||||
|
|
||||||
txt2img = factory.make_generator(Txt2Img)
|
|
||||||
|
|
||||||
outputs = txt2img.generate(
|
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=step_callback,
|
step_callback=step_callback,
|
||||||
**self.dict(
|
**self.dict(
|
||||||
@ -121,13 +113,9 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
factory = context.services.generator_factory
|
manager = context.services.model_manager
|
||||||
self.model = self.model or factory.model_name
|
|
||||||
factory.model_name = self.model
|
|
||||||
img2img = factory.make_generator(Img2Img)
|
|
||||||
|
|
||||||
generator_output = next(
|
generator_output = next(
|
||||||
img2img.generate(
|
Img2Img(manager).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_img=image,
|
init_img=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
@ -186,13 +174,9 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
factory = context.services.generator_factory
|
manager = context.services.model_manager
|
||||||
self.model = self.model or factory.model_name
|
|
||||||
factory.model_name = self.model
|
|
||||||
inpaint = factory.make_generator(Inpaint)
|
|
||||||
|
|
||||||
generator_output = next(
|
generator_output = next(
|
||||||
inpaint.generate(
|
Inpaint(manager).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_img=image,
|
init_img=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
|
@ -6,12 +6,12 @@ from argparse import Namespace
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.version
|
import invokeai.version
|
||||||
from ...backend import ModelManager, InvokeAIGeneratorBasicParams, InvokeAIGeneratorFactory
|
from ...backend import ModelManager
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
|
|
||||||
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
||||||
def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
def get_model_manager(args, config) -> ModelManager:
|
||||||
if not args.conf:
|
if not args.conf:
|
||||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||||
if not os.path.exists(config_file):
|
if not os.path.exists(config_file):
|
||||||
@ -64,7 +64,7 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
|||||||
print(f"{e}. Aborting.")
|
print(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
# creating an InvokeAIGeneratorFactory object:
|
# creating the model manager
|
||||||
try:
|
try:
|
||||||
device = torch.device(choose_torch_device())
|
device = torch.device(choose_torch_device())
|
||||||
precision = 'float16' if args.precision=='float16' \
|
precision = 'float16' if args.precision=='float16' \
|
||||||
@ -77,11 +77,6 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
|||||||
device_type=device,
|
device_type=device,
|
||||||
max_loaded_models=args.max_loaded_models,
|
max_loaded_models=args.max_loaded_models,
|
||||||
)
|
)
|
||||||
# TO DO: initialize and pass safety checker!!!
|
|
||||||
params = InvokeAIGeneratorBasicParams(
|
|
||||||
precision=precision,
|
|
||||||
)
|
|
||||||
factory = InvokeAIGeneratorFactory(model_manager, params)
|
|
||||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
report_model_error(args, e)
|
report_model_error(args, e)
|
||||||
except (IOError, KeyError) as e:
|
except (IOError, KeyError) as e:
|
||||||
@ -100,7 +95,7 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
|||||||
weights_directory=path,
|
weights_directory=path,
|
||||||
)
|
)
|
||||||
|
|
||||||
return factory
|
return model_manager
|
||||||
|
|
||||||
def load_face_restoration(opt):
|
def load_face_restoration(opt):
|
||||||
try:
|
try:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
from invokeai.backend import InvokeAIGeneratorFactory
|
from invokeai.backend import ModelManager
|
||||||
|
|
||||||
from .events import EventServiceBase
|
from .events import EventServiceBase
|
||||||
from .image_storage import ImageStorageBase
|
from .image_storage import ImageStorageBase
|
||||||
@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC
|
|||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
generator_factory: InvokeAIGeneratorFactory
|
model_manager: ModelManager
|
||||||
events: EventServiceBase
|
events: EventServiceBase
|
||||||
images: ImageStorageBase
|
images: ImageStorageBase
|
||||||
queue: InvocationQueueABC
|
queue: InvocationQueueABC
|
||||||
@ -21,14 +21,14 @@ class InvocationServices:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
generator_factory: InvokeAIGeneratorFactory,
|
model_manager: ModelManager,
|
||||||
events: EventServiceBase,
|
events: EventServiceBase,
|
||||||
images: ImageStorageBase,
|
images: ImageStorageBase,
|
||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
):
|
):
|
||||||
self.generator_factory = generator_factory
|
self.model_manager = model_manager
|
||||||
self.events = events
|
self.events = events
|
||||||
self.images = images
|
self.images = images
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
|
@ -4,7 +4,6 @@ Initialization file for invokeai.backend
|
|||||||
from .generate import Generate
|
from .generate import Generate
|
||||||
from .generator import (
|
from .generator import (
|
||||||
InvokeAIGeneratorBasicParams,
|
InvokeAIGeneratorBasicParams,
|
||||||
InvokeAIGeneratorFactory,
|
|
||||||
InvokeAIGenerator,
|
InvokeAIGenerator,
|
||||||
InvokeAIGeneratorOutput,
|
InvokeAIGeneratorOutput,
|
||||||
Txt2Img,
|
Txt2Img,
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
Initialization file for the invokeai.generator package
|
Initialization file for the invokeai.generator package
|
||||||
"""
|
"""
|
||||||
from .base import (
|
from .base import (
|
||||||
InvokeAIGeneratorFactory,
|
|
||||||
InvokeAIGenerator,
|
InvokeAIGenerator,
|
||||||
InvokeAIGeneratorBasicParams,
|
InvokeAIGeneratorBasicParams,
|
||||||
InvokeAIGeneratorOutput,
|
InvokeAIGeneratorOutput,
|
||||||
|
@ -11,7 +11,8 @@ import diffusers
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta
|
||||||
|
from argparse import Namespace
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -21,7 +22,7 @@ from PIL import Image, ImageChops, ImageFilter
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from typing import List, Type, Iterator
|
from typing import List, Iterator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
@ -35,13 +36,13 @@ downsampling = 8
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorBasicParams:
|
class InvokeAIGeneratorBasicParams:
|
||||||
|
model_name: str='stable-diffusion-1.5'
|
||||||
seed: int=None
|
seed: int=None
|
||||||
width: int=512
|
width: int=512
|
||||||
height: int=512
|
height: int=512
|
||||||
cfg_scale: int=7.5
|
cfg_scale: int=7.5
|
||||||
steps: int=20
|
steps: int=20
|
||||||
ddim_eta: float=0.0
|
ddim_eta: float=0.0
|
||||||
model_name: str='stable-diffusion-1.5'
|
|
||||||
scheduler: int='ddim'
|
scheduler: int='ddim'
|
||||||
precision: str='float16'
|
precision: str='float16'
|
||||||
perlin: float=0.0
|
perlin: float=0.0
|
||||||
@ -62,41 +63,8 @@ class InvokeAIGeneratorOutput:
|
|||||||
'''
|
'''
|
||||||
image: Image
|
image: Image
|
||||||
seed: int
|
seed: int
|
||||||
model_name: str
|
|
||||||
model_hash: str
|
model_hash: str
|
||||||
params: dict
|
params: Namespace
|
||||||
|
|
||||||
def __getattribute__(self,name):
|
|
||||||
try:
|
|
||||||
return object.__getattribute__(self, name)
|
|
||||||
except AttributeError:
|
|
||||||
params = object.__getattribute__(self, 'params')
|
|
||||||
if name in params:
|
|
||||||
return params[name]
|
|
||||||
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
|
||||||
|
|
||||||
class InvokeAIGeneratorFactory(object):
|
|
||||||
def __init__(self,
|
|
||||||
model_manager: ModelManager,
|
|
||||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
|
||||||
):
|
|
||||||
self.model_manager = model_manager
|
|
||||||
self.params = params
|
|
||||||
|
|
||||||
def make_generator(self, generatorclass: Type[InvokeAIGenerator], **keyword_args)->InvokeAIGenerator:
|
|
||||||
return generatorclass(self.model_manager,
|
|
||||||
self.params,
|
|
||||||
**keyword_args
|
|
||||||
)
|
|
||||||
|
|
||||||
# getter and setter shortcuts for commonly used parameters
|
|
||||||
@property
|
|
||||||
def model_name(self)->str:
|
|
||||||
return self.params.model_name
|
|
||||||
|
|
||||||
@model_name.setter
|
|
||||||
def model_name(self, model_name: str):
|
|
||||||
self.params.model_name=model_name
|
|
||||||
|
|
||||||
# we are interposing a wrapper around the original Generator classes so that
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
# old code that calls Generate will continue to work.
|
# old code that calls Generate will continue to work.
|
||||||
@ -116,7 +84,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_manager: ModelManager,
|
model_manager: ModelManager,
|
||||||
params: InvokeAIGeneratorBasicParams,
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
):
|
):
|
||||||
self.model_manager=model_manager
|
self.model_manager=model_manager
|
||||||
self.params=params
|
self.params=params
|
||||||
@ -149,23 +117,24 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
print(o.image, o.seed)
|
print(o.image, o.seed)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
model_name = self.params.model_name or self.model_manager.current_model
|
generator_args = dataclasses.asdict(self.params)
|
||||||
|
generator_args.update(keyword_args)
|
||||||
|
|
||||||
|
model_name = generator_args.get('model_name') or self.model_manager.current_model
|
||||||
model_info: dict = self.model_manager.get_model(model_name)
|
model_info: dict = self.model_manager.get_model(model_name)
|
||||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
model_hash = model_info['hash']
|
model_hash = model_info['hash']
|
||||||
scheduler: Scheduler = self.get_scheduler(
|
scheduler: Scheduler = self.get_scheduler(
|
||||||
model=model,
|
model=model,
|
||||||
scheduler_name=self.params.scheduler
|
scheduler_name=generator_args.get('scheduler')
|
||||||
)
|
)
|
||||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||||
generator = self.load_generator(model, self._generator_name())
|
generator = self.load_generator(model, self._generator_name())
|
||||||
if self.params.variation_amount > 0:
|
if self.params.variation_amount > 0:
|
||||||
generator.set_variation(self.params.seed,
|
generator.set_variation(generator_args.get('seed'),
|
||||||
self.params.variation_amount,
|
generator_args.get('variation_amount'),
|
||||||
self.params.with_variations)
|
generator_args.get('with_variations')
|
||||||
|
)
|
||||||
generator_args = dataclasses.asdict(self.params)
|
|
||||||
generator_args.update(keyword_args)
|
|
||||||
|
|
||||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||||
for i in iteration_count:
|
for i in iteration_count:
|
||||||
@ -177,9 +146,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
output = InvokeAIGeneratorOutput(
|
output = InvokeAIGeneratorOutput(
|
||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
seed=results[0][1],
|
seed=results[0][1],
|
||||||
model_name = model_name,
|
|
||||||
model_hash = model_hash,
|
model_hash = model_hash,
|
||||||
params=generator_args,
|
params=Namespace(**generator_args),
|
||||||
)
|
)
|
||||||
if callback:
|
if callback:
|
||||||
callback(output)
|
callback(output)
|
||||||
@ -206,17 +174,18 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
@abstractmethod
|
@classmethod
|
||||||
def _generator_name(self)->str:
|
def _generator_name(cls):
|
||||||
'''
|
'''
|
||||||
In derived classes will return the name of the generator to use.
|
In derived classes return the name of the generator to apply.
|
||||||
|
If you don't override will return the name of the derived
|
||||||
|
class, which nicely parallels the generator class names.
|
||||||
'''
|
'''
|
||||||
pass
|
return cls.__name__
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Txt2Img(InvokeAIGenerator):
|
class Txt2Img(InvokeAIGenerator):
|
||||||
def _generator_name(self)->str:
|
pass
|
||||||
return 'Txt2Img'
|
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
@ -230,9 +199,6 @@ class Img2Img(InvokeAIGenerator):
|
|||||||
**keyword_args
|
**keyword_args
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generator_name(self)->str:
|
|
||||||
return 'Img2Img'
|
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
@ -266,9 +232,6 @@ class Inpaint(Img2Img):
|
|||||||
**keyword_args
|
**keyword_args
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generator_name(self)->str:
|
|
||||||
return 'Inpaint'
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
latent_channels: int
|
latent_channels: int
|
||||||
|
@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
|
|||||||
from invokeai.backend.globals import Globals, global_cache_dir
|
from invokeai.backend.globals import Globals, global_cache_dir
|
||||||
|
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
|
||||||
|
|
||||||
class SDLegacyType(Enum):
|
class SDLegacyType(Enum):
|
||||||
V1 = 1
|
V1 = 1
|
||||||
|
Loading…
Reference in New Issue
Block a user