Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Fixes

commit 36eb1bd893
parent 9fa78443de
@@ -64,7 +64,7 @@ class ModelLoaderInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> ModelLoaderOutput:

-        base_model = BaseModelType.StableDiffusion2 # TODO:
+        base_model = BaseModelType.StableDiffusion1_5 # TODO:

         # TODO: not found exceptions
         if not context.services.model_manager.model_exists(
@@ -116,7 +116,7 @@ class ModelLoaderInvocation(BaseInvocation):
                     model_name=self.model_name,
                     base_model=base_model,
                     model_type=ModelType.Pipeline,
-                    submodel=SDModelType.Scheduler,
+                    submodel=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
@@ -125,13 +125,13 @@ class ModelLoaderInvocation(BaseInvocation):
                     model_name=self.model_name,
                     base_model=base_model,
                     model_type=ModelType.Pipeline,
-                    submodel=SDModelType.Tokenizer,
+                    submodel=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     model_name=self.model_name,
                     base_model=base_model,
                     model_type=ModelType.Pipeline,
-                    submodel=SDModelType.TextEncoder,
+                    submodel=SubModelType.TextEncoder,
                 ),
                 loras=[],
             ),
@@ -140,7 +140,7 @@ class ModelLoaderInvocation(BaseInvocation):
                     model_name=self.model_name,
                     base_model=base_model,
                     model_type=ModelType.Pipeline,
-                    submodel=SDModelType.Vae,
+                    submodel=SubModelType.Vae,
                 ),
             )
         )
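The three hunks above (-116, -125, -140) are a mechanical rename: the submodel selector moves from the old SDModelType enum to SubModelType. A minimal sketch of what that enum presumably looks like, using only the members visible in this diff (the repo's real definition may carry more members and different values):

import enum

class SubModelType(str, enum.Enum):
    # Individually loadable components of a diffusers pipeline.
    Scheduler = "scheduler"
    Tokenizer = "tokenizer"
    TextEncoder = "text_encoder"
    Vae = "vae"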
@@ -162,7 +162,8 @@ class ModelCache(object):
         if model_info_key not in self.model_infos:
             self.model_infos[model_info_key] = model_class(
                 model_path,
-                model_class,
+                base_model,
+                model_type,
             )

         return self.model_infos[model_info_key]
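The -162 hunk fixes a positional mix-up: the cache was passing model_class into the constructor slot where base_model evidently belongs, and not passing model_type at all. Python accepts wrong positional arguments without complaint as long as the arity matches, so the misfiled value only surfaces when something downstream reads it; the -443 hunk further below hardens a neighboring call site by naming every argument. A small illustration of the failure mode, with hypothetical stand-in names:

class ModelRecord:
    """Hypothetical stand-in for the cached per-model entry."""
    def __init__(self, model_path: str, base_model: str, model_type: str):
        self.model_path = model_path
        self.base_model = base_model
        self.model_type = model_type

# Arity matches, so this runs; but base_model now silently holds a class object.
bad = ModelRecord("/models/foo", ModelRecord, "pipeline")

# Keyword arguments fail loudly on a wrong or missing name.
good = ModelRecord(model_path="/models/foo", base_model="sd-1", model_type="pipeline")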
@@ -208,7 +209,7 @@ class ModelCache(object):

         # clean memory to make MemoryUsage() more accurate
         gc.collect()
-        model = model_info.get_model(submodel, torch_dtype=self.precision)
+        model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
         if mem_used := model_info.get_size(submodel):
             self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
@@ -430,10 +430,10 @@ class ModelManager(object):

         # vae/movq override
         # TODO:
-        if submodel is not None and submodel in model_config:
-            model_path = model_config[submodel]
-            model_type = submodel
-            submodel = None
+        if submodel_type is not None and submodel_type in model_config:
+            model_path = model_config[submodel_type]
+            model_type = submodel_type
+            submodel_type = None

         dst_convert_path = None # TODO:
         model_path = model_class.convert_if_required(
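The -430 hunk only renames submodel to submodel_type, but the logic it touches is worth spelling out: when the pipeline's config carries a standalone override entry for the requested component (the vae/movq case named in the comment), the manager loads that entry as a model in its own right instead of extracting the component from the pipeline. A sketch of the same flow as a standalone function (hypothetical signature; the config is assumed to be a dict-like mapping keyed by submodel type):

def resolve_override(model_config: dict, model_path: str, model_type, submodel_type):
    # If the config names a separate path for this component (e.g. a
    # standalone VAE), treat that path as the model itself: retarget the
    # path and type, and clear the submodel request.
    if submodel_type is not None and submodel_type in model_config:
        model_path = model_config[submodel_type]
        model_type = submodel_type
        submodel_type = None
    return model_path, model_type, submodel_type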
@@ -443,9 +443,11 @@ class ModelManager(object):
         )

         model_context = self.cache.get_model(
-            model_path,
-            model_class,
-            submodel_type,
+            model_path=model_path,
+            model_class=model_class,
+            base_model=base_model,
+            model_type=model_type,
+            submodel=submodel_type,
         )

         hash = "<NO_HASH>" # TODO:
@@ -1,3 +1,4 @@
+import os
 import sys
 import typing
 import inspect
@@ -102,7 +102,8 @@ class StableDiffusion15Model(DiffusersModel):
         elif kwargs["format"] == "diffusers":
             unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
             if os.path.exists(unet_config_path):
-                unet_config = json.loads(unet_config_path)
+                with open(unet_config_path, "r") as f:
+                    unet_config = json.loads(f.read())
                 in_channels = unet_config['in_channels']

             else:
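The -102 hunk fixes a classic slip: json.loads() parses a JSON string, so handing it a file path raises json.JSONDecodeError on the first character of the path instead of reading the file. The commit opens the file and parses its contents; json.load(f) is the usual one-call shorthand for the same thing:

import json

def read_in_channels(unet_config_path: str) -> int:
    with open(unet_config_path, "r") as f:
        unet_config = json.load(f)   # equivalent to json.loads(f.read())
    return unet_config["in_channels"]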
@@ -135,10 +136,14 @@ class StableDiffusion15Model(DiffusersModel):
             return "checkpoint"

     @classmethod
-    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
-        cfg = cls.build_config(**config)
-        if isinstance(cfg, cls.CheckpointConfig):
-            return _convert_ckpt_and_cache(cfg) # TODO: args
+    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
+        if isinstance(config, cls.CheckpointConfig):
+            return _convert_ckpt_and_cache(
+                version=BaseModelType.StableDiffusion1_5,
+                config=config.dict(),
+                in_path=model_path,
+                out_path=dst_cache_path,
+            ) # TODO: args
         else:
             return model_path
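The rewrite of convert_if_required drops the cls.build_config(**config) round-trip: the caller now hands in an already-parsed ModelConfigBase object, and the method dispatches on its concrete type. Checkpoint-format models get a one-time conversion to diffusers layout; models already in diffusers format pass through untouched. The config.dict() call suggests these configs are pydantic models being serialized back to a plain dict for the converter. A self-contained sketch of the dispatch shape, with simplified stand-in types:

from dataclasses import dataclass

@dataclass
class ModelConfigBase:
    path: str

@dataclass
class CheckpointConfig(ModelConfigBase):
    config: str   # path to the original inference .yaml

def _convert_and_cache(in_path: str, out_path: str) -> str:
    """Hypothetical stand-in for the real checkpoint-to-diffusers conversion."""
    return out_path

def convert_if_required(model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
    # Only checkpoint-format models need conversion; diffusers models are
    # used where they stand.
    if isinstance(config, CheckpointConfig):
        return _convert_and_cache(model_path, dst_cache_path)
    return model_path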
@@ -154,10 +159,22 @@ class StableDiffusion2BaseModel(StableDiffusion15Model):
             model_type=ModelType.Pipeline,
         )

-class StableDiffusion2Model(DiffusersModel):
+    @classmethod
+    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
+        if isinstance(config, cls.CheckpointConfig):
+            return _convert_ckpt_and_cache(
+                version=BaseModelType.StableDiffusion2Base,
+                config=config.dict(),
+                in_path=model_path,
+                out_path=dst_cache_path,
+            ) # TODO: args
+        else:
+            return model_path
+
+class StableDiffusion2Model(StableDiffusion15Model):

     # TODO: str -> Path?
-    # overwrite configs
+    # TODO: check that configs overwriten
     class DiffusersConfig(ModelConfigBase):
         format: Literal["diffusers"]
         vae: Optional[str] = Field(None)
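With this hunk, StableDiffusion2BaseModel gains a convert_if_required that is identical to the one on StableDiffusion15Model except for the version argument, and the -174 hunk below adds a third copy for StableDiffusion2Model. One way to collapse the triplication (a sketch of a design alternative, not what the commit does) is to hoist the version into a class attribute and implement the method once on the shared base:

import enum

class BaseModelType(str, enum.Enum):
    # Member names taken from this diff; values are illustrative.
    StableDiffusion1_5 = "sd-1.5"
    StableDiffusion2Base = "sd-2-base"
    StableDiffusion2 = "sd-2"

def _convert_ckpt_and_cache(version, config, in_path, out_path) -> str:
    return out_path   # stand-in for the real conversion

class SD15Model:
    version = BaseModelType.StableDiffusion1_5

    @classmethod
    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: dict) -> str:
        if config.get("format") == "checkpoint":
            return _convert_ckpt_and_cache(
                version=cls.version,      # the only per-class difference
                config=config,
                in_path=model_path,
                out_path=dst_cache_path,
            )
        return model_path

class SD2BaseModel(SD15Model):
    version = BaseModelType.StableDiffusion2Base

class SD2Model(SD15Model):
    version = BaseModelType.StableDiffusion2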
@@ -174,22 +191,49 @@ class StableDiffusion2Model(DiffusersModel):
         # skip StableDiffusion15Model __init__
         assert base_model == BaseModelType.StableDiffusion2
         assert model_type == ModelType.Pipeline
-        super().__init__(
+        # skip StableDiffusion15Model __init__
+        super(StableDiffusion15Model, self).__init__(
             model_path=model_path,
             base_model=BaseModelType.StableDiffusion2,
             model_type=ModelType.Pipeline,
         )

+    @classmethod
+    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
+        if isinstance(config, cls.CheckpointConfig):
+            return _convert_ckpt_and_cache(
+                version=BaseModelType.StableDiffusion2,
+                config=config.dict(),
+                in_path=model_path,
+                out_path=dst_cache_path,
+            ) # TODO: args
+        else:
+            return model_path
+
+
 # TODO: rework
 DictConfig = dict
-def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
+def _convert_ckpt_and_cache(
+    self,
+    version: BaseModelType,
+    mconfig: dict, # TODO:
+    in_path: str,
+    out_path: str,
+) -> str:
     """
     Convert the checkpoint model indicated in mconfig into a
     diffusers, cache it to disk, and return Path to converted
     file. If already on disk then just returns Path.
     """
+    raise NotImplementedError()
     app_config = InvokeAIAppConfig.get_config()

+    #if "config" not in mconfig:
+    #    if version == BaseModelType.StableDiffusion1_5:
+    #if
+    #mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
+
+
     weights = app_config.root_dir / mconfig.path
     config_file = app_config.root_dir / mconfig.config
     diffusers_path = app_config.converted_ckpts_dir / weights.stem
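Two details in the last hunk deserve a note. First, super(StableDiffusion15Model, self).__init__(...): two-argument super() starts the method lookup after the named class in the MRO, so calling it from StableDiffusion2Model (now a subclass of StableDiffusion15Model) skips the parent's __init__ and runs the grandparent's, which is exactly what the "# skip StableDiffusion15Model __init__" comment asks for. A minimal demonstration:

class Grandparent:
    def __init__(self):
        print("Grandparent.__init__")

class Parent(Grandparent):
    def __init__(self):
        print("Parent.__init__")   # the step being bypassed
        super().__init__()

class Child(Parent):
    def __init__(self):
        # Lookup starts after Parent in Child's MRO, so Parent.__init__
        # is skipped and Grandparent.__init__ runs directly.
        super(Parent, self).__init__()

Child()   # prints only "Grandparent.__init__"

Second, the raise NotImplementedError() inserted at the top of _convert_ckpt_and_cache makes everything after it unreachable: the half-ported body below stays in the file for the pending rework (see the surrounding TODOs) but can no longer run with stale arguments.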