Fix conflict resolution, add model configs to type annotation

Sergey Borisov 2023-06-14 00:26:37 +03:00
parent c9ae26a176
commit 26090011c4
3 changed files with 39 additions and 27 deletions


@@ -7,6 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
 from pydantic import BaseModel, Field, parse_obj_as
 from ..dependencies import ApiDependencies
 from invokeai.backend import BaseModelType, ModelType
+from invokeai.backend.model_management.models import get_all_model_configs
+MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
 models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -60,7 +62,7 @@ class ConvertedModelResponse(BaseModel):
     info: DiffusersModelInfo = Field(description="The converted model info")
 class ModelsList(BaseModel):
-    models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
+    models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
     #models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
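
Side note (not part of the commit): a minimal, self-contained sketch of the dynamic-Union pattern used for MODEL_CONFIGS above, with two hypothetical config classes standing in for whatever get_all_model_configs() actually returns, and pydantic's v1-style parse_obj_as (the same helper imported above):

from typing import Literal, Union
from pydantic import BaseModel, parse_obj_as

class DiffusersConfig(BaseModel):       # hypothetical stand-in for a real model config
    format: Literal["diffusers"] = "diffusers"
    path: str

class CheckpointConfig(BaseModel):      # hypothetical stand-in for a real model config
    format: Literal["checkpoint"] = "checkpoint"
    path: str

# Union[...] accepts a tuple of types, so a list of config classes collected at
# runtime can be folded into a single annotation, as the router does above:
MODEL_CONFIGS = Union[tuple([DiffusersConfig, CheckpointConfig])]

# pydantic then validates a payload against whichever member matches:
cfg = parse_obj_as(MODEL_CONFIGS, {"format": "checkpoint", "path": "model.ckpt"})
print(type(cfg).__name__)               # -> CheckpointConfig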


@@ -3,6 +3,8 @@
 from contextlib import ExitStack
 from typing import List, Literal, Optional, Union
+import einops
 from pydantic import BaseModel, Field, validator
 import torch
 from diffusers import ControlNetModel
@@ -235,11 +237,12 @@ class TextToLatentsInvocation(BaseInvocation):
         return conditioning_data
     def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
-        configure_model_padding(
-            unet,
-            self.seamless,
-            self.seamless_axes,
-        )
+        # TODO:
+        #configure_model_padding(
+        #    unet,
+        #    self.seamless,
+        #    self.seamless_axes,
+        #)
         class FakeVae:
             class FakeVaeConfig:
@@ -261,13 +264,15 @@ class TextToLatentsInvocation(BaseInvocation):
             precision="float16" if unet.dtype == torch.float16 else "float32",
         )
-    def prep_control_data(self,
+    def prep_control_data(
+        self,
         context: InvocationContext,
         model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
         control_input: List[ControlField],
         latents_shape: List[int],
         do_classifier_free_guidance: bool = True,
     ) -> List[ControlNetData]:
         # assuming fixed dimensional scaling of 8:1 for image:latents
         control_height_resize = latents_shape[2] * 8
         control_width_resize = latents_shape[3] * 8
@@ -362,10 +367,12 @@ class TextToLatentsInvocation(BaseInvocation):
             loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
-            control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
+            control_data = self.prep_control_data(
+                model=pipeline, context=context, control_input=self.control,
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                do_classifier_free_guidance=True,)
+                do_classifier_free_guidance=True,
+            )
             with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
                 # TODO: Verify the noise is the right size
@@ -434,7 +441,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         pipeline = self.create_pipeline(unet, scheduler)
         conditioning_data = self.get_conditioning_data(context, scheduler)
-        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
+        control_data = self.prep_control_data(
+            model=pipeline, context=context, control_input=self.control,
             latents_shape=noise.shape,
             # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
             do_classifier_free_guidance=True,


@@ -33,10 +33,12 @@ MODEL_CLASSES = {
     #},
 }
+# TODO: check with openapi annotation
 def get_all_model_configs():
-    configs = []
+    configs = set()
     for models in MODEL_CLASSES.values():
-        for model in models.values():
-            configs.extend(model._get_configs())
-    return configs
+        for type, model in models.items():
+            if type == ModelType.ControlNet:
+                continue # TODO:
+            configs.update(model._get_configs().values())
+    configs.discard(None)
+    return list(configs) # TODO: set, list or tuple
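
Side note (not part of the commit): a rough sketch of how the rewritten helper behaves, using a hypothetical MODEL_CLASSES registry and hypothetical config classes (the real registry maps a base model to ModelType to a model class whose _get_configs() yields format -> config class):

from enum import Enum
from typing import Union

class ModelType(str, Enum):             # simplified stand-in for the real enum
    Pipeline = "pipeline"
    ControlNet = "controlnet"

class PipelineConfigA: pass             # hypothetical config classes
class PipelineConfigB: pass

class FakePipelineModel:                # hypothetical model class
    @classmethod
    def _get_configs(cls):
        return {"diffusers": PipelineConfigA, "checkpoint": PipelineConfigB, "unknown": None}

MODEL_CLASSES = {
    "sd-1": {ModelType.Pipeline: FakePipelineModel},
    "sd-2": {ModelType.Pipeline: FakePipelineModel},   # shared class: the set() dedupes its configs
}

def get_all_model_configs():
    configs = set()
    for models in MODEL_CLASSES.values():
        for type, model in models.items():
            if type == ModelType.ControlNet:
                continue                 # ControlNet configs are skipped for now, as above
            configs.update(model._get_configs().values())
    configs.discard(None)                # drop formats that have no config class
    return list(configs)

print(sorted(c.__name__ for c in get_all_model_configs()))  # ['PipelineConfigA', 'PipelineConfigB']
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]        # as consumed by the router change above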