Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
model manager defaults to consistent values of device and precision
parent 5d37fa6e36
commit b679a6ba37
@@ -6,7 +6,10 @@ from .generator import (
     InvokeAIGeneratorBasicParams,
     InvokeAIGeneratorFactory,
     InvokeAIGenerator,
-    InvokeAIGeneratorOutput
+    InvokeAIGeneratorOutput,
+    Txt2Img,
+    Img2Img,
+    Inpaint
 )
 from .model_management import ModelManager
 from .args import Args

@@ -5,6 +5,7 @@ including img2img, txt2img, and inpaint
 from __future__ import annotations

 import importlib
+import itertools
 import dataclasses
 import diffusers
 import os

@@ -20,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
 from accelerate.utils import set_seed
 from diffusers import DiffusionPipeline
 from tqdm import trange
-from typing import List, Type
+from typing import List, Type, Iterator
 from dataclasses import dataclass, field
 from diffusers.schedulers import SchedulerMixin as Scheduler

@@ -77,7 +78,7 @@ class InvokeAIGeneratorOutput:
 class InvokeAIGeneratorFactory(object):
     def __init__(self,
                  model_manager: ModelManager,
-                 params: InvokeAIGeneratorBasicParams
+                 params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
                  ):
         self.model_manager = model_manager
         self.params = params

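With the default now supplied by the factory, a caller can build it from just a ModelManager. A minimal sketch, assuming these names are re-exported from invokeai.backend as the first hunk suggests; the config path and model name are placeholders, not facts from this diff:

    from invokeai.backend import ModelManager, InvokeAIGeneratorFactory, InvokeAIGeneratorBasicParams

    manager = ModelManager('configs/models.yaml')   # placeholder path, not from this diff
    factory = InvokeAIGeneratorFactory(manager)     # falls back to InvokeAIGeneratorBasicParams()
    tuned = InvokeAIGeneratorFactory(
        manager,
        InvokeAIGeneratorBasicParams(model_name='stable-diffusion-1.5'),  # hypothetical model name
    )
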
@@ -115,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):

     def __init__(self,
                  model_manager: ModelManager,
-                 params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
+                 params: InvokeAIGeneratorBasicParams,
                  ):
         self.model_manager=model_manager
         self.params=params

@@ -124,9 +125,30 @@ class InvokeAIGenerator(metaclass=ABCMeta):
                  prompt: str='',
                  callback: callable=None,
                  step_callback: callable=None,
                  iterations: int=1,
                  **keyword_args,
-                 )->List[InvokeAIGeneratorOutput]:
+                 )->Iterator[InvokeAIGeneratorOutput]:
+        '''
+        Return an iterator across the indicated number of generations.
+        Each time the iterator is called it will return an InvokeAIGeneratorOutput
+        object. Use like this:
+
+               outputs = txt2img.generate(prompt='banana sushi', iterations=5)
+               for result in outputs:
+                   print(result.image, result.seed)
+
+        In the typical case of wanting to get just a single image, iterations
+        defaults to 1, so you can do:
+
+               output = next(txt2img.generate(prompt='banana sushi'))
+
+        Pass None to get an infinite iterator:
+
+               outputs = txt2img.generate(prompt='banana sushi', iterations=None)
+               for o in outputs:
+                   print(o.image, o.seed)
+
+        '''
         model_name = self.params.model_name or self.model_manager.current_model
         model_info: dict = self.model_manager.get_model(model_name)
         model:StableDiffusionGeneratorPipeline = model_info['model']

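The docstring above shows the intended call pattern; a hedged end-to-end sketch follows, with the import path, config location, and prompt treated as placeholders rather than facts from this diff:

    from invokeai.backend import ModelManager, Txt2Img, InvokeAIGeneratorBasicParams

    manager = ModelManager('configs/models.yaml')              # placeholder path
    txt2img = Txt2Img(manager, InvokeAIGeneratorBasicParams())

    # generate() now returns a lazy iterator rather than a list, so each
    # InvokeAIGeneratorOutput can be consumed as soon as it is produced.
    for result in txt2img.generate(prompt='banana sushi', iterations=3):
        print(result.image, result.seed)
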
@@ -149,8 +171,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):

         generator_args = dataclasses.asdict(self.params)
         generator_args.update(keyword_args)

-        while True:
+        iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
+        for i in iteration_count:
             results = generator.generate(prompt,
                                          conditioning=(uc, c, extra_conditioning_info),
                                          sampler=scheduler,

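The new itertools import exists to support iterations=None: a bounded range when a count is given, an unbounded itertools.count otherwise. A standalone sketch of that pattern (the helper name is illustrative, not part of the diff):

    import itertools

    def iteration_counter(iterations):
        # Bounded when a count is supplied; infinite when iterations is None (or 0).
        return range(iterations) if iterations else itertools.count(start=0, step=1)

    print(list(iteration_counter(3)))           # [0, 1, 2]
    print(next(iter(iteration_counter(None))))  # 0, and the counter never exhausts
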
@@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path
 from invokeai.backend.globals import Globals, global_cache_dir

 from ..stable_diffusion import StableDiffusionGeneratorPipeline
-from ..util import CPU_DEVICE, ask_user, download_with_resume
-
+from ..util import CUDA_DEVICE, ask_user, download_with_resume

 class SDLegacyType(Enum):
     V1 = 1

@@ -51,23 +50,28 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
 }

 class ModelManager(object):
+    '''
+    Model manager handles loading, caching, importing, deleting, converting, and editing models.
+    '''
     def __init__(
         self,
-        config: OmegaConf,
-        device_type: torch.device = CPU_DEVICE,
+        config: OmegaConf|Path,
+        device_type: torch.device = CUDA_DEVICE,
         precision: str = "float16",
         max_loaded_models=DEFAULT_MAX_MODELS,
         sequential_offload=False,
     ):
         """
-        Initialize with the path to the models.yaml config file,
-        the torch device type, and precision. The optional
-        min_avail_mem argument specifies how much unused system
-        (CPU) memory to preserve. The cache of models in RAM will
-        grow until this value is approached. Default is 2G.
+        Initialize with the path to the models.yaml config file or
+        an initialized OmegaConf dictionary. Optional parameters
+        are the torch device type, precision, max_loaded_models,
+        and sequential_offload boolean. Note that the default device
+        type and precision are set up for a CUDA system running at half precision.
         """
         # prevent nasty-looking CLIP log message
         transformers.logging.set_verbosity_error()
+        if not isinstance(config, DictConfig):
+            config = OmegaConf.load(config)
         self.config = config
         self.precision = precision
         self.device = torch.device(device_type)

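This is the change the commit message refers to: ModelManager now defaults to a CUDA device at float16 precision, so callers get consistent values without passing anything. A hedged sketch, with the import path and config location assumed rather than taken from this diff:

    import torch
    from invokeai.backend import ModelManager

    # Defaults after this change: device_type=CUDA_DEVICE, precision="float16".
    manager = ModelManager('configs/models.yaml')              # placeholder path

    # Explicit override for a CPU-only machine, where half precision is not useful.
    cpu_manager = ModelManager(
        'configs/models.yaml',
        device_type=torch.device('cpu'),
        precision='float32',
    )
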
@@ -557,7 +561,7 @@ class ModelManager(object):
         """
         model_name = model_name or Path(repo_or_path).stem
         model_description = (
-            model_description or f"Imported diffusers model {model_name}"
+            description or f"Imported diffusers model {model_name}"
         )
         new_config = dict(
             description=model_description,