model manager defaults to consistent values of device and precision

This commit is contained in:
Lincoln Stein 2023-03-09 01:09:54 -05:00
parent 5d37fa6e36
commit b679a6ba37
3 changed files with 47 additions and 17 deletions

View File

@ -6,7 +6,10 @@ from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorFactory,
InvokeAIGenerator,
InvokeAIGeneratorOutput
InvokeAIGeneratorOutput,
Txt2Img,
Img2Img,
Inpaint
)
from .model_management import ModelManager
from .args import Args

View File

@ -5,6 +5,7 @@ including img2img, txt2img, and inpaint
from __future__ import annotations
import importlib
import itertools
import dataclasses
import diffusers
import os
@ -20,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Type
from typing import List, Type, Iterator
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
@ -77,7 +78,7 @@ class InvokeAIGeneratorOutput:
class InvokeAIGeneratorFactory(object):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager = model_manager
self.params = params
@ -115,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
params: InvokeAIGeneratorBasicParams,
):
self.model_manager=model_manager
self.params=params
@ -124,9 +125,30 @@ class InvokeAIGenerator(metaclass=ABCMeta):
prompt: str='',
callback: callable=None,
step_callback: callable=None,
iterations: int=1,
**keyword_args,
)->List[InvokeAIGeneratorOutput]:
)->Iterator[InvokeAIGeneratorOutput]:
'''
Return an iterator across the indicated number of generations.
Each time the iterator is called it will return an InvokeAIGeneratorOutput
object. Use like this:
outputs = txt2img.generate(prompt='banana sushi', iterations=5)
for result in outputs:
print(result.image, result.seed)
In the typical case of wanting to get just a single image, iterations
defaults to 1 and do:
output = next(txt2img.generate(prompt='banana sushi')
Pass None to get an infinite iterator.
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
for o in outputs:
print(o.image, o.seed)
'''
model_name = self.params.model_name or self.model_manager.current_model
model_info: dict = self.model_manager.get_model(model_name)
model:StableDiffusionGeneratorPipeline = model_info['model']
@ -149,8 +171,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
while True:
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
sampler=scheduler,

View File

@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CPU_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1
@ -51,23 +50,28 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
}
class ModelManager(object):
'''
Model manager handles loading, caching, importing, deleting, converting, and editing models.
'''
def __init__(
self,
config: OmegaConf,
device_type: torch.device = CPU_DEVICE,
config: OmegaConf|Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
):
"""
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
min_avail_mem argument specifies how much unused system
(CPU) memory to preserve. The cache of models in RAM will
grow until this value is approached. Default is 2G.
Initialize with the path to the models.yaml config file or
an initialized OmegaConf dictionary. Optional parameters
are the torch device type, precision, max_loaded_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
# prevent nasty-looking CLIP log message
transformers.logging.set_verbosity_error()
if not isinstance(config, DictConfig):
config = OmegaConf.load(config)
self.config = config
self.precision = precision
self.device = torch.device(device_type)
@ -557,7 +561,7 @@ class ModelManager(object):
"""
model_name = model_name or Path(repo_or_path).stem
model_description = (
model_description or f"Imported diffusers model {model_name}"
description or f"Imported diffusers model {model_name}"
)
new_config = dict(
description=model_description,