mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main, fix conflicts
This commit is contained in:
commit
54f3686e3b
@ -1,75 +1,30 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from fastapi import Query, Body
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from typing import Literal, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import AddModelResult
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
model_name: str = Field(description="The name of the model")
|
||||
model_type: str = Field(description="The type of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['folder'] = 'folder'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
|
||||
config: str = Field(description="The path to the model config")
|
||||
weights: str = Field(description="The path to the model weights")
|
||||
vae: str = Field(description="The path to the model VAE")
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class SafetensorsModelInfo(CkptModelInfo):
|
||||
format: Literal['safetensors'] = 'safetensors'
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ImportModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the imported model")
|
||||
# base_model: str = Field(description="The base model")
|
||||
# model_type: str = Field(description="The model type")
|
||||
info: AddModelResult = Field(description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[MODEL_CONFIGS]
|
||||
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
@ -77,75 +32,103 @@ class ModelsList(BaseModel):
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models(
|
||||
base_model: Optional[BaseModelType] = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
responses={200: {"status": "success"}},
|
||||
responses={200: {"description" : "The model was updated successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
400: {"description" : "Bad request"}
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = UpdateModelResponse,
|
||||
)
|
||||
async def update_model(
|
||||
model_request: CreateModelRequest
|
||||
) -> CreateModelResponse:
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
""" Add Model """
|
||||
model_request_info = model_request.info
|
||||
info_dict = model_request_info.dict()
|
||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
model_name=model_request.name,
|
||||
model_attributes=info_dict,
|
||||
clobber=True,
|
||||
)
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info.dict()
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
"/",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def import_model(
|
||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
items_to_import = {name}
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
if info := installed_models.get(name):
|
||||
logger.info(f'Successfully imported {name}, got {info}')
|
||||
return ImportModelResponse(
|
||||
name = name,
|
||||
info = info,
|
||||
status = "success",
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
else:
|
||||
logger.error(f'Model {name} not imported')
|
||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=424)
|
||||
|
||||
logger.info(f'Successfully imported {location}, got {info}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
@ -156,144 +139,95 @@ async def import_model(
|
||||
}
|
||||
},
|
||||
)
|
||||
async def delete_model(model_name: str) -> None:
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error("Model not found")
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except KeyError:
|
||||
logger.error(f"Model not found: {model_name}")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
# model_name=model_to_convert["model_name"]
|
||||
# ):
|
||||
# if "weights" in model_info:
|
||||
# ckpt_path = Path(model_info["weights"])
|
||||
# original_config_file = Path(model_info["config"])
|
||||
# model_name = model_to_convert["model_name"]
|
||||
# model_description = model_info["description"]
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||
# )
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Could not retrieve model info."}
|
||||
# )
|
||||
|
||||
# if not ckpt_path.is_absolute():
|
||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
# if original_config_file and not original_config_file.is_absolute():
|
||||
# original_config_file = Path(Globals.root, original_config_file)
|
||||
|
||||
# diffusers_path = Path(
|
||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if model_to_convert["save_location"] == "root":
|
||||
# diffusers_path = Path(
|
||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# model_to_convert["save_location"] == "custom"
|
||||
# and model_to_convert["custom_location"] is not None
|
||||
# ):
|
||||
# diffusers_path = Path(
|
||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if diffusers_path.exists():
|
||||
# shutil.rmtree(diffusers_path)
|
||||
|
||||
# self.generate.model_manager.convert_and_import(
|
||||
# ckpt_path,
|
||||
# diffusers_path,
|
||||
# model_name=model_name,
|
||||
# model_description=model_description,
|
||||
# vae=None,
|
||||
# original_config_file=original_config_file,
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelConverted",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Converted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("mergeDiffusersModels")
|
||||
# def merge_diffusers_models(model_merge_info: dict):
|
||||
# try:
|
||||
# models_to_merge = model_merge_info["models_to_merge"]
|
||||
# model_ids_or_paths = [
|
||||
# self.generate.model_manager.model_name_or_path(x)
|
||||
# for x in models_to_merge
|
||||
# ]
|
||||
# merged_pipe = merge_diffusion_models(
|
||||
# model_ids_or_paths,
|
||||
# model_merge_info["alpha"],
|
||||
# model_merge_info["interp"],
|
||||
# model_merge_info["force"],
|
||||
# )
|
||||
|
||||
# dump_path = global_models_dir() / "merged_models"
|
||||
# if model_merge_info["model_merge_save_path"] is not None:
|
||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||
|
||||
# os.makedirs(dump_path, exist_ok=True)
|
||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
|
||||
# merged_model_config = dict(
|
||||
# model_name=model_merge_info["merged_model_name"],
|
||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
# "vae", None
|
||||
# ):
|
||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
# merged_model_config.update(vae=vae)
|
||||
|
||||
# self.generate.model_manager.import_diffuser_model(
|
||||
# dump_path, **merged_model_config
|
||||
# )
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
|
||||
# socketio.emit(
|
||||
# "modelsMerged",
|
||||
# {
|
||||
# "merged_models": models_to_merge,
|
||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: {"description" : "Bad request" },
|
||||
404: { "description": "Model not found" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = ConvertModelResponse,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: { "description": "Incompatible models" },
|
||||
404: { "description": "One or more models not found" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names}")
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
merged_model_name or "+".join(model_names),
|
||||
alpha,
|
||||
interp,
|
||||
force)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -2,22 +2,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
from pydantic import Field
|
||||
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
from invokeai.backend.model_management import (
|
||||
ModelManager,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
AddModelResult,
|
||||
SchedulerPredictionType,
|
||||
ModelMerger,
|
||||
MergeInterpolationMethod,
|
||||
)
|
||||
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from .config import InvokeAIAppConfig
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC):
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC):
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
) -> None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
KeyErrorException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@ -301,12 +374,19 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
# ) -> dict:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -322,9 +402,28 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'add/update model {model_name}')
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
KeyError exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'update model {model_name}')
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise KeyError(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -336,8 +435,29 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
self.logger.debug(f'delete model {model_name}')
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f'convert model {model_name}')
|
||||
return self.mgr.convert_model(model_name, base_model, model_type)
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
"""
|
||||
@ -389,9 +509,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -408,4 +528,35 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_names = model_names,
|
||||
base_model = base_model,
|
||||
merged_model_name = merged_model_name,
|
||||
alpha = alpha,
|
||||
interp = interp,
|
||||
force = force,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
@ -173,15 +173,19 @@ class ModelInstall(object):
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
self.heuristic_import(path)
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
@ -194,62 +198,53 @@ class ModelInstall(object):
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update(self._install_path(path))
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path):self._install_path(path)})
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||
models_installed.update(self._install_repo(str(model_path_id_or_url)))
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# a URL
|
||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update(self._install_url(model_path_id_or_url))
|
||||
# a URL
|
||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
else:
|
||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
else:
|
||||
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
|
||||
return models_installed
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
||||
try:
|
||||
model_result = None
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
model_result = self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'{str(e)} Skipping registration.')
|
||||
return {}
|
||||
return {str(path): model_result}
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f'Unable to parse format of {path}')
|
||||
return None
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
return self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
|
||||
def _install_url(self, url: str)->dict:
|
||||
# copy to a staging area, probe, import and delete
|
||||
def _install_url(self, url: str)->AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
if not location:
|
||||
@ -261,7 +256,7 @@ class ModelInstall(object):
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str)->dict:
|
||||
def _install_repo(self, repo_id: str)->AddModelResult:
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||
|
||||
|
@ -2,16 +2,14 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Tuple, Any, Union, List
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
class LoRALayerBase:
|
||||
#rank: Optional[int]
|
||||
@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
def get_weight(self):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
|
||||
else:
|
||||
# TODO: diff/ia3/... format
|
||||
print(
|
||||
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
|
||||
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -234,7 +234,7 @@ import textwrap
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree
|
||||
from shutil import rmtree, move
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
@ -279,7 +279,7 @@ class InvalidModelError(Exception):
|
||||
pass
|
||||
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after import")
|
||||
name: str = Field(description="The name of the model after installation")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
base_model: BaseModelType = Field(description="The base model")
|
||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||
@ -490,17 +490,32 @@ class ModelManager(object):
|
||||
"""
|
||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||
|
||||
def list_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> dict:
|
||||
"""
|
||||
Returns a dict describing one installed model, using
|
||||
the combined format of the list_models() method.
|
||||
"""
|
||||
models = self.list_models(base_model,model_type,model_name)
|
||||
return models[0] if models else None
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_name: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
|
||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||
models = []
|
||||
for model_key in sorted(self.models, key=str.casefold):
|
||||
for model_key in model_keys:
|
||||
model_config = self.models[model_key]
|
||||
|
||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
@ -545,10 +560,7 @@ class ModelManager(object):
|
||||
model_cfg = self.models.pop(model_key, None)
|
||||
|
||||
if model_cfg is None:
|
||||
self.logger.error(
|
||||
f"Unknown model {model_key}"
|
||||
)
|
||||
return
|
||||
raise KeyError(f"Unknown model {model_key}")
|
||||
|
||||
# note: it not garantie to release memory(model can has other references)
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
@ -614,6 +626,7 @@ class ModelManager(object):
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models[model_key] = model_config
|
||||
self.commit()
|
||||
return AddModelResult(
|
||||
name = model_name,
|
||||
model_type = model_type,
|
||||
@ -621,6 +634,60 @@ class ModelManager(object):
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def convert_model (
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
'''
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
'''
|
||||
info = self.model_info(model_name, base_model, model_type)
|
||||
if info["model_format"] != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `location`. It doesn't matter
|
||||
# what submodeltype we request here, so we get the smallest.
|
||||
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
|
||||
model = self.get_model(model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
**submodel,
|
||||
)
|
||||
checkpoint_path = self.app_config.root_path / info["path"]
|
||||
old_diffusers_path = self.app_config.models_path / model.location
|
||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
|
||||
try:
|
||||
move(old_diffusers_path,new_diffusers_path)
|
||||
info["model_format"] = "diffusers"
|
||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info.pop('config')
|
||||
|
||||
result = self.add_model(model_name, base_model, model_type,
|
||||
model_attributes = info,
|
||||
clobber=True)
|
||||
except:
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
rmtree(new_diffusers_path)
|
||||
raise
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
return result
|
||||
|
||||
def search_models(self, search_folder):
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
@ -821,6 +888,10 @@ class ModelManager(object):
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
|
||||
May return the following exceptions:
|
||||
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
|
||||
- ValueError - a corresponding model already exists
|
||||
'''
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
@ -830,11 +901,7 @@ class ModelManager(object):
|
||||
prediction_type_helper = prediction_type_helper,
|
||||
model_manager = self)
|
||||
for thing in items_to_import:
|
||||
try:
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
except Exception as e:
|
||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
131
invokeai/backend/model_management/model_merge.py
Normal file
131
invokeai/backend/model_management/model_merge.py
Normal file
@ -0,0 +1,131 @@
|
||||
"""
|
||||
invokeai.backend.model_management.model_merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from typing import List, Union
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
WeightedSum = "weighted_sum"
|
||||
Sigmoid = "sigmoid"
|
||||
InvSigmoid = "inv_sigmoid"
|
||||
AddDifference = "add_difference"
|
||||
|
||||
class ModelMerger(object):
|
||||
def __init__(self, manager: ModelManager):
|
||||
self.manager = manager
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_paths[0],
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_paths,
|
||||
alpha=alpha,
|
||||
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_save (
|
||||
self,
|
||||
model_names: List[str],
|
||||
base_model: Union[BaseModelType,str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param base_model: base model (must be the same for all merged models!)
|
||||
:param merged_model_name: name for new model
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = list()
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
vae = None
|
||||
|
||||
for mod in model_names:
|
||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
||||
assert len(model_names) <= 2 or \
|
||||
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
|
||||
# pick up the first model's vae
|
||||
if mod == model_names[0]:
|
||||
vae = info.get("vae")
|
||||
model_paths.extend([config.root_path / info["path"]])
|
||||
|
||||
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
|
||||
logger.debug(f'interp = {interp}, merge_method={merge_method}')
|
||||
merged_pipe = self.merge_diffusion_models(
|
||||
model_paths, alpha, merge_method, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
attributes = dict(
|
||||
path = str(dump_path),
|
||||
description = f"Merge of models {', '.join(model_names)}",
|
||||
model_format = "diffusers",
|
||||
variant = ModelVariantType.Normal.value,
|
||||
vae = vae,
|
||||
)
|
||||
return self.manager.add_model(merged_model_name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
model_attributes = attributes,
|
||||
clobber = True
|
||||
)
|
@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""
|
||||
Initialization file for invokeai.frontend.merge
|
||||
"""
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers
|
||||
|
||||
|
@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
@ -20,99 +18,15 @@ from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management import ModelManager
|
||||
from ...frontend.install.widgets import FloatTitleSlider
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import (
|
||||
ModelMerger, MergeInterpolationMethod,
|
||||
ModelManager, ModelType, BaseModelType,
|
||||
)
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_ids_or_paths: List[Union[str, Path]],
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get("cache_dir", config.cache_dir),
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_commit(
|
||||
models: List["str"],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
models - up to three models, designated by their InvokeAI models.yaml model name
|
||||
merged_model_name = name for new model
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
config_file = config.model_conf_path
|
||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||
for mod in models:
|
||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
||||
assert (
|
||||
model_manager.model_info(mod).get("format", None) == "diffusers"
|
||||
), f"** {mod} is not a diffusers model. It must be optimized before merging."
|
||||
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
||||
|
||||
merged_pipe = merge_diffusion_models(
|
||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
|
||||
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
import_args = dict(
|
||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||
)
|
||||
if vae := model_manager.config[models[0]].get("vae", None):
|
||||
logger.info(f"Using configured VAE assigned to {models[0]}")
|
||||
import_args.update(vae=vae)
|
||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||
model_manager.commit(config_file)
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
@ -131,10 +45,17 @@ def _parse_args() -> Namespace:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
dest="model_names",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
|
||||
self.model_names = self.get_model_names()
|
||||
self.current_base = 0
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||
editable=False,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
'Models Built on SD-1.x',
|
||||
'Models Built on SD-2.x',
|
||||
],
|
||||
value=[self.current_base],
|
||||
columns = 4,
|
||||
max_height = 2,
|
||||
relx=8,
|
||||
scroll_exit = True,
|
||||
)
|
||||
self.base_select.on_changed = self._populate_models
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 1",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
rely=4 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
rely=5,
|
||||
rely=7,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=4 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=5 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=4 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=5 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
)
|
||||
for m in [self.model1, self.model2, self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
TextBox,
|
||||
name="Name for merged model:",
|
||||
labelColor="CONTROL",
|
||||
max_height=3,
|
||||
value="",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force merge of incompatible models",
|
||||
name="Force merge of models created by different diffusers library versions",
|
||||
labelColor="CONTROL",
|
||||
value=False,
|
||||
value=True,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Merge Method:",
|
||||
@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = dict(
|
||||
models=models,
|
||||
model_names=models,
|
||||
base_model=tuple(BaseModelType)[self.base_select.value[0]],
|
||||
alpha=self.alpha.value,
|
||||
interp=interp,
|
||||
force=self.force.value,
|
||||
@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> List[str]:
|
||||
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
|
||||
model_names = [
|
||||
name
|
||||
for name in self.model_manager.model_names()
|
||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
||||
info["name"]
|
||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||
if info["model_format"] == "diffusers"
|
||||
]
|
||||
return sorted(model_names)
|
||||
|
||||
def _populate_models(self,value=None):
|
||||
base_model = tuple(BaseModelType)[value[0]]
|
||||
self.model_names = self.get_model_names(base_model)
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model1.values = self.model_names
|
||||
self.model2.values = self.model_names
|
||||
self.model3.values = models_plus_none
|
||||
|
||||
self.display()
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
def __init__(self, model_manager:ModelManager):
|
||||
super().__init__()
|
||||
conf = OmegaConf.load(config.model_conf_path)
|
||||
self.model_manager = ModelManager(
|
||||
conf, "cpu", "float16"
|
||||
) # precision doesn't really matter here
|
||||
self.model_manager = model_manager
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||
@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
|
||||
mergeapp = Mergeapp()
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
mergeapp = Mergeapp(model_manager)
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merge_diffusion_models_and_commit(**args)
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||
assert (
|
||||
args.models and len(args.models) >= 1 and len(args.models) <= 3
|
||||
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
|
||||
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.models)
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(
|
||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
|
||||
assert (
|
||||
args.clobber or args.merged_model_name not in model_manager.model_names()
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
assert (
|
||||
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merge_diffusion_models_and_commit(**vars(args))
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**vars(args))
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
config.root = args.root_dir
|
||||
|
||||
cache_dir = config.cache_dir
|
||||
os.environ[
|
||||
"HF_HOME"
|
||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
||||
args.cache_dir = cache_dir
|
||||
config.parse_args(['--root',str(args.root_dir)])
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
moduleLog.debug({ timestamp }, 'Connected');
|
||||
|
||||
const { nodes, config, gallery } = getState();
|
||||
const { nodes, config } = getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
if (!gallery.ids.length) {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}
|
||||
|
@ -1,15 +1,16 @@
|
||||
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { memo } from 'react';
|
||||
import { RefObject, memo } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, ...rest } = props;
|
||||
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
||||
const {
|
||||
base50,
|
||||
base100,
|
||||
@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<MultiSelect
|
||||
ref={inputRef}
|
||||
searchable={searchable}
|
||||
styles={() => ({
|
||||
label: {
|
||||
|
@ -6,10 +6,15 @@ import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
modelsApi,
|
||||
useGetMainModelsQuery,
|
||||
} from '../../services/api/endpoints/models';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ generation, system, batch }, activeTabName) => {
|
||||
(state, activeTabName) => {
|
||||
const { generation, system, batch } = state;
|
||||
const { shouldGenerateVariations, seedWeights, initialImage, seed } =
|
||||
generation;
|
||||
|
||||
@ -32,6 +37,13 @@ const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
|
||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||
modelsApi.endpoints.getMainModels.select()(state);
|
||||
if (!mainModelsSuccessfullyLoaded) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('Models are not loaded');
|
||||
}
|
||||
|
||||
// TODO: job queue
|
||||
// Cannot generate if already processing an image
|
||||
if (isProcessing) {
|
||||
|
@ -0,0 +1,33 @@
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { memo } from 'react';
|
||||
import { BiCode } from 'react-icons/bi';
|
||||
|
||||
type Props = {
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
const AddEmbeddingButton = (props: Props) => {
|
||||
const { onClick } = props;
|
||||
return (
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
aria-label="Add Embedding"
|
||||
tooltip="Add Embedding"
|
||||
icon={<BiCode />}
|
||||
sx={{
|
||||
p: 2,
|
||||
color: 'base.700',
|
||||
_hover: {
|
||||
color: 'base.550',
|
||||
},
|
||||
_active: {
|
||||
color: 'base.500',
|
||||
},
|
||||
}}
|
||||
variant="link"
|
||||
onClick={onClick}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AddEmbeddingButton);
|
@ -0,0 +1,151 @@
|
||||
import {
|
||||
Flex,
|
||||
Popover,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import {
|
||||
PropsWithChildren,
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useRef,
|
||||
} from 'react';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||
|
||||
type EmbeddingSelectItem = {
|
||||
label: string;
|
||||
value: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
type Props = PropsWithChildren & {
|
||||
onSelect: (v: string) => void;
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
const ParamEmbeddingPopover = (props: Props) => {
|
||||
const { onSelect, isOpen, onClose, children } = props;
|
||||
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!embeddingQueryData) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: EmbeddingSelectItem[] = [];
|
||||
|
||||
forEach(embeddingQueryData.entities, (embedding, _) => {
|
||||
if (!embedding) return;
|
||||
|
||||
data.push({
|
||||
value: embedding.name,
|
||||
label: embedding.name,
|
||||
description: embedding.description,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [embeddingQueryData]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
if (v.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
onSelect(v[0]);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
initialFocusRef={inputRef}
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
placement="bottom"
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
>
|
||||
<PopoverTrigger>{children}</PopoverTrigger>
|
||||
<PopoverContent
|
||||
sx={{
|
||||
p: 0,
|
||||
top: -1,
|
||||
shadow: 'dark-lg',
|
||||
borderColor: 'accent.300',
|
||||
borderWidth: '2px',
|
||||
borderStyle: 'solid',
|
||||
_dark: { borderColor: 'accent.400' },
|
||||
}}
|
||||
>
|
||||
<PopoverBody
|
||||
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
|
||||
>
|
||||
{data.length === 0 ? (
|
||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||
<Text
|
||||
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
|
||||
>
|
||||
No Embeddings Loaded
|
||||
</Text>
|
||||
</Flex>
|
||||
) : (
|
||||
<IAIMantineMultiSelect
|
||||
inputRef={inputRef}
|
||||
placeholder={'Add Embedding'}
|
||||
value={[]}
|
||||
data={data}
|
||||
maxDropdownHeight={400}
|
||||
nothingFound="No Matching Embeddings"
|
||||
itemComponent={SelectItem}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: EmbeddingSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
)}
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamEmbeddingPopover;
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, description, ...others }: ItemProps, ref) => {
|
||||
return (
|
||||
<div ref={ref} {...others}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SelectItem.displayName = 'SelectItem';
|
@ -23,6 +23,7 @@ export const makeSelector = (image_name: string) =>
|
||||
({ gallery }) => {
|
||||
const isSelected = gallery.selection.includes(image_name);
|
||||
const selectionCount = gallery.selection.length;
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
selectionCount,
|
||||
@ -117,7 +118,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
resetIcon={<FaTrash />}
|
||||
resetTooltip="Delete image"
|
||||
imageSx={{ w: 'full', h: 'full' }}
|
||||
withResetIcon
|
||||
// withResetIcon // removed bc it's too easy to accidentally delete images
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
/>
|
||||
|
@ -182,6 +182,15 @@ const ImageGalleryContent = () => {
|
||||
return () => osInstance()?.destroy();
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickImagesCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||
dispatch(setGalleryView('images'));
|
||||
|
@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
|
||||
import {
|
||||
Lora,
|
||||
loraRemoved,
|
||||
loraWeightChanged,
|
||||
loraWeightReset,
|
||||
} from '../store/loraSlice';
|
||||
|
||||
type Props = {
|
||||
lora: Lora;
|
||||
@ -22,7 +27,7 @@ const ParamLora = (props: Props) => {
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
|
||||
dispatch(loraWeightReset(lora.id));
|
||||
}, [dispatch, lora.id]);
|
||||
|
||||
const handleRemoveLora = useCallback(() => {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Text } from '@chakra-ui/react';
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@ -61,6 +61,16 @@ const ParamLoraSelect = () => {
|
||||
[dispatch, lorasQueryData?.entities]
|
||||
);
|
||||
|
||||
if (lorasQueryData?.ids.length === 0) {
|
||||
return (
|
||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
|
||||
No LoRAs Loaded
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<IAIMantineMultiSelect
|
||||
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
|
||||
|
@ -8,7 +8,7 @@ export type Lora = {
|
||||
};
|
||||
|
||||
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
||||
weight: 1,
|
||||
weight: 0.75,
|
||||
};
|
||||
|
||||
export type LoraState = {
|
||||
@ -38,9 +38,14 @@ export const loraSlice = createSlice({
|
||||
const { id, weight } = action.payload;
|
||||
state.loras[id].weight = weight;
|
||||
},
|
||||
loraWeightReset: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
state.loras[id].weight = defaultLoRAConfig.weight;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
|
||||
loraSlice.actions;
|
||||
|
||||
export default loraSlice.reducer;
|
||||
|
@ -1,29 +1,107 @@
|
||||
import { FormControl } from '@chakra-ui/react';
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamNegativeConditioning = () => {
|
||||
const negativePrompt = useAppSelector(
|
||||
(state: RootState) => state.generation.negativePrompt
|
||||
);
|
||||
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setNegativePrompt(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = negativePrompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += negativePrompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setNegativePrompt(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, negativePrompt]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<IAITextarea
|
||||
id="negativePrompt"
|
||||
name="negativePrompt"
|
||||
value={negativePrompt}
|
||||
onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
|
||||
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||
fontSize="sm"
|
||||
minH={16}
|
||||
/>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="negativePrompt"
|
||||
name="negativePrompt"
|
||||
ref={promptRef}
|
||||
value={negativePrompt}
|
||||
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
fontSize="sm"
|
||||
minH={16}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
{!isOpen && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Box, FormControl } from '@chakra-ui/react';
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
@ -11,12 +11,15 @@ import {
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const promptInputSelector = createSelector(
|
||||
[(state: RootState) => state.generation, activeTabNameSelector],
|
||||
@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setPositivePrompt(e.target.value));
|
||||
};
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setPositivePrompt(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'alt+a',
|
||||
@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => {
|
||||
[]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setPositivePrompt(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, prompt]
|
||||
);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||
@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => {
|
||||
dispatch(clampSymmetrySteps());
|
||||
dispatch(userInvoked(activeTabName));
|
||||
}
|
||||
if (e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[dispatch, activeTabName, isReady]
|
||||
[isReady, dispatch, activeTabName, onOpen]
|
||||
);
|
||||
|
||||
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||
// const target = e.target as HTMLTextAreaElement;
|
||||
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||
// };
|
||||
|
||||
return (
|
||||
<Box>
|
||||
<FormControl>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
placeholder={t('parameters.positivePromptPlaceholder')}
|
||||
value={prompt}
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
ref={promptRef}
|
||||
minH={32}
|
||||
/>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={promptRef}
|
||||
value={prompt}
|
||||
placeholder={t('parameters.positivePromptPlaceholder')}
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
minH={32}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
{!isOpen && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 6,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,3 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton, {
|
||||
IAIIconButtonProps,
|
||||
@ -25,26 +24,25 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
|
||||
};
|
||||
|
||||
return (
|
||||
<Tooltip label={t('common.pinOptionsPanel')}>
|
||||
<IAIIconButton
|
||||
{...props}
|
||||
aria-label={t('common.pinOptionsPanel')}
|
||||
onClick={handleClickPinOptionsPanel}
|
||||
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
sx={{
|
||||
color: 'base.700',
|
||||
_hover: {
|
||||
color: 'base.550',
|
||||
},
|
||||
_active: {
|
||||
color: 'base.500',
|
||||
},
|
||||
...sx,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
<IAIIconButton
|
||||
{...props}
|
||||
tooltip={t('common.pinOptionsPanel')}
|
||||
aria-label={t('common.pinOptionsPanel')}
|
||||
onClick={handleClickPinOptionsPanel}
|
||||
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
sx={{
|
||||
color: 'base.700',
|
||||
_hover: {
|
||||
color: 'base.550',
|
||||
},
|
||||
_active: {
|
||||
color: 'base.500',
|
||||
},
|
||||
...sx,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
import { AddNewModelType, UIState } from './uiTypes';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
|
||||
export const initialUIState: UIState = {
|
||||
activeTab: 0,
|
||||
@ -19,6 +19,7 @@ export const initialUIState: UIState = {
|
||||
shouldShowGallery: true,
|
||||
shouldHidePreview: false,
|
||||
shouldShowProgressInViewer: true,
|
||||
shouldShowEmbeddingPicker: false,
|
||||
favoriteSchedulers: [],
|
||||
};
|
||||
|
||||
@ -96,6 +97,9 @@ export const uiSlice = createSlice({
|
||||
) => {
|
||||
state.favoriteSchedulers = action.payload;
|
||||
},
|
||||
toggleEmbeddingPicker: (state) => {
|
||||
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(initialImageChanged, (state) => {
|
||||
@ -122,6 +126,7 @@ export const {
|
||||
toggleGalleryPanel,
|
||||
setShouldShowProgressInViewer,
|
||||
favoriteSchedulersChanged,
|
||||
toggleEmbeddingPicker,
|
||||
} = uiSlice.actions;
|
||||
|
||||
export default uiSlice.reducer;
|
||||
|
@ -27,5 +27,6 @@ export interface UIState {
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
shouldShowProgressInViewer: boolean;
|
||||
shouldShowEmbeddingPicker: boolean;
|
||||
favoriteSchedulers: SchedulerParam[];
|
||||
}
|
||||
|
@ -1,18 +1,18 @@
|
||||
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import { io, Socket } from 'socket.io-client';
|
||||
import { Socket, io } from 'socket.io-client';
|
||||
|
||||
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import { getTimestamp } from 'common/util/getTimestamp';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import {
|
||||
ClientToServerEvents,
|
||||
ServerToClientEvents,
|
||||
} from 'services/events/types';
|
||||
import { socketSubscribed, socketUnsubscribed } from './actions';
|
||||
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import { getTimestamp } from 'common/util/getTimestamp';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
// import { OpenAPI } from 'services/api/types';
|
||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { $authToken, $baseUrl } from 'services/api/client';
|
||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||
|
||||
const socketioLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
@ -88,7 +88,7 @@ export const socketMiddleware = () => {
|
||||
socketSubscribed({
|
||||
sessionId: sessionId,
|
||||
timestamp: getTimestamp(),
|
||||
boardId: getState().boards.selectedBoardId,
|
||||
boardId: getState().gallery.selectedBoardId,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user