Mirror of https://github.com/invoke-ai/InvokeAI
commit b679a6ba37 (parent 5d37fa6e36)

    model manager defaults to consistent values of device and precision
@@ -6,7 +6,10 @@ from .generator import (
     InvokeAIGeneratorBasicParams,
     InvokeAIGeneratorFactory,
     InvokeAIGenerator,
-    InvokeAIGeneratorOutput
+    InvokeAIGeneratorOutput,
+    Txt2Img,
+    Img2Img,
+    Inpaint
 )
 from .model_management import ModelManager
 from .args import Args
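With the widened export list, the concrete generators become importable straight from the backend package. A sketch of the resulting caller-side import, assuming (as the relative imports suggest, though the diff does not name the file) that this hunk edits invokeai/backend/__init__.py:

    # Hypothetical caller-side import; the package path is inferred from the
    # relative imports above, not confirmed by the diff itself.
    from invokeai.backend import Txt2Img, Img2Img, Inpaint, InvokeAIGeneratorBasicParams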
@@ -5,6 +5,7 @@ including img2img, txt2img, and inpaint
 from __future__ import annotations
 
 import importlib
+import itertools
 import dataclasses
 import diffusers
 import os
@@ -20,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
 from accelerate.utils import set_seed
 from diffusers import DiffusionPipeline
 from tqdm import trange
-from typing import List, Type
+from typing import List, Type, Iterator
 from dataclasses import dataclass, field
 from diffusers.schedulers import SchedulerMixin as Scheduler
 
@@ -77,7 +78,7 @@ class InvokeAIGeneratorOutput:
 class InvokeAIGeneratorFactory(object):
     def __init__(self,
                  model_manager: ModelManager,
-                 params: InvokeAIGeneratorBasicParams
+                 params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
                  ):
         self.model_manager = model_manager
         self.params = params
@@ -115,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
 
     def __init__(self,
                  model_manager: ModelManager,
-                 params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
+                 params: InvokeAIGeneratorBasicParams,
                  ):
         self.model_manager=model_manager
         self.params=params
@@ -124,9 +125,30 @@ class InvokeAIGenerator(metaclass=ABCMeta):
                  prompt: str='',
                  callback: callable=None,
                  step_callback: callable=None,
+                 iterations: int=1,
                  **keyword_args,
-                 )->List[InvokeAIGeneratorOutput]:
+                 )->Iterator[InvokeAIGeneratorOutput]:
+        '''
+        Return an iterator across the indicated number of generations.
+        Each time the iterator is called it will return an InvokeAIGeneratorOutput
+        object. Use like this:
+
+           outputs = txt2img.generate(prompt='banana sushi', iterations=5)
+           for result in outputs:
+               print(result.image, result.seed)
+
+        In the typical case of wanting just a single image, iterations
+        defaults to 1, so you can simply do:
+
+           output = next(txt2img.generate(prompt='banana sushi'))
+
+        Pass iterations=None to get an infinite iterator:
+
+           outputs = txt2img.generate(prompt='banana sushi', iterations=None)
+           for o in outputs:
+               print(o.image, o.seed)
+
+        '''
         model_name = self.params.model_name or self.model_manager.current_model
         model_info: dict = self.model_manager.get_model(model_name)
         model:StableDiffusionGeneratorPipeline = model_info['model']
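Because generate() now returns a lazy Iterator rather than a List, the caller decides how many results to realize. A standalone sketch of the consumption patterns the docstring describes; fake_generate is a stand-in for a real generator's generate() method, not InvokeAI API:

    from itertools import islice
    from typing import Iterator, Optional

    def fake_generate(prompt: str, iterations: Optional[int] = 1) -> Iterator[str]:
        # Stand-in for InvokeAIGenerator.generate(): yields one result per pass,
        # endlessly when iterations is None.
        i = 0
        while iterations is None or i < iterations:
            yield f"{prompt}-{i}"
            i += 1

    single = next(fake_generate('banana sushi'))                              # the one-image case
    batch  = list(fake_generate('banana sushi', iterations=5))                # realize all five
    capped = list(islice(fake_generate('banana sushi', iterations=None), 3))  # cap an endless stream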
@@ -149,8 +171,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):
 
         generator_args = dataclasses.asdict(self.params)
         generator_args.update(keyword_args)
 
-        while True:
+        iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
+        for i in iteration_count:
             results = generator.generate(prompt,
                                          conditioning=(uc, c, extra_conditioning_info),
                                          sampler=scheduler,
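The replacement for `while True:` leans on a small idiom: range(n) ends after n passes, while itertools.count() never does, so an unbounded run is simply handed an endless counter and termination is left to whoever consumes the iterator. A minimal runnable illustration (passes is an illustrative helper, not part of the codebase):

    import itertools

    def passes(iterations):
        # Mirrors the diff line: finite range for a count, endless counter for None.
        # Note that iterations=0 is also falsy, so it too yields the endless counter.
        return range(iterations) if iterations else itertools.count(start=0, step=1)

    print(list(passes(3)))               # [0, 1, 2]
    endless = passes(None)
    print(next(endless), next(endless))  # 0 1, and so on for as long as you ask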
@@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path
 from invokeai.backend.globals import Globals, global_cache_dir
 
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
-from ..util import CPU_DEVICE, ask_user, download_with_resume
-
+from ..util import CUDA_DEVICE, ask_user, download_with_resume
 
 class SDLegacyType(Enum):
     V1 = 1
@@ -51,23 +50,28 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
 }
 
 class ModelManager(object):
+    '''
+    Model manager handles loading, caching, importing, deleting, converting, and editing models.
+    '''
     def __init__(
         self,
-        config: OmegaConf,
-        device_type: torch.device = CPU_DEVICE,
+        config: OmegaConf|Path,
+        device_type: torch.device = CUDA_DEVICE,
         precision: str = "float16",
         max_loaded_models=DEFAULT_MAX_MODELS,
         sequential_offload=False,
     ):
         """
-        Initialize with the path to the models.yaml config file,
-        the torch device type, and precision. The optional
-        min_avail_mem argument specifies how much unused system
-        (CPU) memory to preserve. The cache of models in RAM will
-        grow until this value is approached. Default is 2G.
+        Initialize with the path to the models.yaml config file or
+        an initialized OmegaConf dictionary. Optional parameters
+        are the torch device type, precision, max_loaded_models,
+        and sequential_offload boolean. Note that the default device
+        type and precision are set up for a CUDA system running at half precision.
         """
         # prevent nasty-looking CLIP log message
         transformers.logging.set_verbosity_error()
+        if not isinstance(config, DictConfig):
+            config = OmegaConf.load(config)
         self.config = config
         self.precision = precision
         self.device = torch.device(device_type)
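Per the commit subject, the constructor now defaults to a CUDA device at float16 rather than CPU, and it accepts either a models.yaml path or an already-parsed OmegaConf mapping, normalizing with an isinstance check before use. A standalone sketch of that normalization idiom, with 'models.yaml' as a placeholder path:

    from pathlib import Path
    from omegaconf import OmegaConf, DictConfig

    def load_config(config):
        # Accept a str/Path to a YAML file or a ready DictConfig, as in the diff.
        if not isinstance(config, DictConfig):
            config = OmegaConf.load(config)
        return config

    cfg = load_config(Path('models.yaml'))                                    # path form
    cfg = load_config(OmegaConf.create({'sd-1.5': {'format': 'diffusers'}}))  # mapping form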
@@ -557,7 +561,7 @@ class ModelManager(object):
         """
         model_name = model_name or Path(repo_or_path).stem
         model_description = (
-            model_description or f"Imported diffusers model {model_name}"
+            description or f"Imported diffusers model {model_name}"
         )
         new_config = dict(
             description=model_description,