Merge branch 'main' into lstein/configure-max-cache-size
@@ -193,7 +193,10 @@ class ModelInstall(object):
models_installed.update(self._install_path(path))

# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
elif path.is_dir() and any([(path/x).exists() for x in \
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))

# recursive scan
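A minimal sketch of the folder-style detection this hunk extends; the marker file names come from the diff, but FOLDER_MARKERS and is_folder_model are hypothetical helpers, not part of ModelInstall:

# Sketch, assuming a directory counts as an installable model folder when it
# contains any of the marker files listed in the diff above.
from pathlib import Path

FOLDER_MARKERS = {
    'config.json',                # diffusers sub-model config
    'model_index.json',           # diffusers pipeline index
    'learned_embeds.bin',         # textual inversion embedding
    'pytorch_lora_weights.bin',   # diffusers-format LoRA (newly recognized)
}

def is_folder_model(path: Path) -> bool:
    """Return True if the directory looks like a folder-style model."""
    return path.is_dir() and any((path / name).exists() for name in FOLDER_MARKERS)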
@@ -3,15 +3,13 @@ from __future__ import annotations
import copy
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union, List

import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from transformers import CLIPTextModel

from transformers import CLIPTextModel, CLIPTokenizer

class LoRALayerBase:
#rank: Optional[int]
@@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):

def get_weight(self):
if self.mid is not None:
up = self.up.reshape(up.shape[0], up.shape[1])
down = self.down.reshape(up.shape[0], up.shape[1])
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
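For reference, a self-contained sketch of the weight reconstruction the corrected get_weight() performs. The shapes below (rank 4, 8 input and 16 output channels, 3x3 kernel) are assumptions chosen for illustration, not values from the codebase:

import torch

up = torch.randn(16, 4)     # lora_up: (out_features, rank)
down = torch.randn(4, 8)    # lora_down: (rank, in_features)

# Plain LoRA branch: low-rank product reconstructs a 16x8 weight update.
weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
assert weight.shape == (16, 8)

# LoCon-style branch with a spatial mid tensor of shape (rank, rank, kh, kw).
mid = torch.randn(4, 4, 3, 3)
weight_conv = torch.einsum("m n w h, i m, n j -> i j w h", mid, up, down)
assert weight_conv.shape == (16, 8, 3, 3)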
@@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
else:
# TODO: diff/ia3/... format
print(
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
)
return

@@ -785,7 +785,7 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path)

@@ -794,7 +794,8 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
new_models_found.update(installer.heuristic_import(path))
import_result = installer.heuristic_import(path)
new_models_found.update(import_result)

self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found)
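A hedged sketch of the checkpoint-file scan shown above. The suffix set is taken from the diff; scan_for_checkpoints is a simplified stand-in for the ModelManager loop, and it assumes installer.heuristic_import returns a mapping of installed models:

from pathlib import Path

CHECKPOINT_SUFFIXES = {'.ckpt', '.bin', '.pth', '.safetensors', '.pt'}

def scan_for_checkpoints(root: Path, installer, known_paths, scanned_dirs):
    new_models_found = {}
    for path in root.rglob('*'):
        if path in known_paths or path.parent in scanned_dirs:
            continue
        if path.suffix in CHECKPOINT_SUFFIXES:
            # Capture the result in a variable first, as the new code does,
            # then merge it into the running set of discovered models.
            import_result = installer.heuristic_import(path)
            new_models_found.update(import_result)
    return new_models_found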
@@ -78,7 +78,6 @@ class ModelProbe(object):
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
else:
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'

model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
@@ -105,7 +104,7 @@ class ModelProbe(object):
) else 512,
)
except Exception:
return None
raise

return model_info

@@ -127,6 +126,8 @@ class ModelProbe(object):
return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
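A minimal sketch of the state-dict key probing this hunk extends. The prefixes and suffixes are the ones in the diff; classify_key is a hypothetical helper, not the ModelProbe API:

def classify_key(key: str) -> str:
    if any(key.startswith(p) for p in ("lora_te_", "lora_unet_")):
        return "Lora"            # kohya-style LoRA keys
    if any(key.endswith(s) for s in ("to_k_lora.up.weight", "to_q_lora.down.weight")):
        return "Lora"            # diffusers-style LoRA keys (newly recognized)
    if any(key.startswith(p) for p in ("control_model", "input_blocks")):
        return "ControlNet"
    if key in ("emb_params", "string_to_param"):
        return "TextualInversion"
    return "unknown"

# e.g. classify_key("...attn1.processor.to_q_lora.down.weight") -> "Lora"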
@@ -137,7 +138,7 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion

raise ValueError("Unable to determine model type")
raise ValueError(f"Unable to determine model type for {model_path}")

@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@@ -167,7 +168,7 @@ class ModelProbe(object):
return type

# give up
raise ValueError("Unable to determine model type")
raise ValueError(f"Unable to determine model type for {folder_path}")

@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict: