mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix conflict resolve, add model configs to type annotation
This commit is contained in:
parent
c9ae26a176
commit
26090011c4
@ -7,6 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
|
|||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
|
from invokeai.backend.model_management.models import get_all_model_configs
|
||||||
|
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
@ -60,7 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
|||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
|
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
|
||||||
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import einops
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
import torch
|
import torch
|
||||||
from diffusers import ControlNetModel
|
from diffusers import ControlNetModel
|
||||||
@ -235,11 +237,12 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||||
configure_model_padding(
|
# TODO:
|
||||||
unet,
|
#configure_model_padding(
|
||||||
self.seamless,
|
# unet,
|
||||||
self.seamless_axes,
|
# self.seamless,
|
||||||
)
|
# self.seamless_axes,
|
||||||
|
#)
|
||||||
|
|
||||||
class FakeVae:
|
class FakeVae:
|
||||||
class FakeVaeConfig:
|
class FakeVaeConfig:
|
||||||
@ -261,13 +264,15 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
precision="float16" if unet.dtype == torch.float16 else "float32",
|
precision="float16" if unet.dtype == torch.float16 else "float32",
|
||||||
)
|
)
|
||||||
|
|
||||||
def prep_control_data(self,
|
def prep_control_data(
|
||||||
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||||
control_input: List[ControlField],
|
control_input: List[ControlField],
|
||||||
latents_shape: List[int],
|
latents_shape: List[int],
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[ControlNetData]:
|
) -> List[ControlNetData]:
|
||||||
|
|
||||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||||
control_height_resize = latents_shape[2] * 8
|
control_height_resize = latents_shape[2] * 8
|
||||||
control_width_resize = latents_shape[3] * 8
|
control_width_resize = latents_shape[3] * 8
|
||||||
@ -362,10 +367,12 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
|
|
||||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
control_data = self.prep_control_data(
|
||||||
|
model=pipeline, context=context, control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,)
|
do_classifier_free_guidance=True,
|
||||||
|
)
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
@ -434,7 +441,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
pipeline = self.create_pipeline(unet, scheduler)
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||||
|
|
||||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
control_data = self.prep_control_data(
|
||||||
|
model=pipeline, context=context, control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
|
@ -33,10 +33,12 @@ MODEL_CLASSES = {
|
|||||||
#},
|
#},
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO: check with openapi annotation
|
|
||||||
def get_all_model_configs():
|
def get_all_model_configs():
|
||||||
configs = []
|
configs = set()
|
||||||
for models in MODEL_CLASSES.values():
|
for models in MODEL_CLASSES.values():
|
||||||
for model in models.values():
|
for type, model in models.items():
|
||||||
configs.extend(model._get_configs())
|
if type == ModelType.ControlNet:
|
||||||
return configs
|
continue # TODO:
|
||||||
|
configs.update(model._get_configs().values())
|
||||||
|
configs.discard(None)
|
||||||
|
return list(configs) # TODO: set, list or tuple
|
||||||
|
Loading…
Reference in New Issue
Block a user