diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index bbe372ff57..c11ebd3f56 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -108,14 +108,15 @@ class CompelInvocation(BaseInvocation):
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=self.clip.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
@@ -196,14 +197,15 @@ class SDXLPromptInvocationBase:
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
@@ -270,14 +272,15 @@ class SDXLPromptInvocationBase:
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py
index 2bec128b87..dec5b939a0 100644
--- a/invokeai/app/invocations/onnx.py
+++ b/invokeai/app/invocations/onnx.py
@@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
             **self.clip.text_encoder.dict(),
         )
         with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
-            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
             loras = [
                 (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                 for lora in self.clip.loras
@@ -75,20 +74,14 @@ class ONNXPromptInvocation(BaseInvocation):
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
-                    ti_list.append(
-                        # stack.enter_context(
-                        #     context.services.model_manager.get_model(
-                        #         model_name=name,
-                        #         base_model=self.clip.text_encoder.base_model,
-                        #         model_type=ModelType.TextualInversion,
-                        #     )
-                        # )
+                    ti_list.append((
+                        name,
                         context.services.model_manager.get_model(
                             model_name=name,
                             base_model=self.clip.text_encoder.base_model,
                             model_type=ModelType.TextualInversion,
                         ).context.model
-                    )
+                    ))
                 except Exception:
                     # print(e)
                     # import traceback
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index 7ccf5e57ae..e8e2b3f51f 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -164,7 +164,7 @@ class ModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: CLIPTextModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         init_tokens_count = None
         new_tokens_added = None
@@ -174,27 +174,27 @@ class ModelPatcher:
             ti_manager = TextualInversionManager(ti_tokenizer)
             init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

-            def _get_trigger(ti, index):
-                trigger = ti.name
+            def _get_trigger(ti_name, index):
+                trigger = ti_name
                 if index > 0:
                     trigger += f"-!pad-{i}"
                 return f"<{trigger}>"

             # modify tokenizer
             new_tokens_added = 0
-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 for i in range(ti.embedding.shape[0]):
-                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

             # modify text_encoder
             text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
             model_embeddings = text_encoder.get_input_embeddings()

-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 ti_tokens = []
                 for i in range(ti.embedding.shape[0]):
                     embedding = ti.embedding[i]
-                    trigger = _get_trigger(ti, i)
+                    trigger = _get_trigger(ti_name, i)

                     token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                     if token_id == ti_tokenizer.unk_token_id:
@@ -239,7 +239,6 @@ class ModelPatcher:


 class TextualInversionModel:
-    name: str
     embedding: torch.Tensor  # [n, 768]|[n, 1280]

     @classmethod
@@ -253,7 +252,6 @@ class TextualInversionModel:
         file_path = Path(file_path)

         result = cls()  # TODO:
-        result.name = file_path.stem  # TODO:

         if file_path.suffix == ".safetensors":
             state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@@ -430,7 +428,7 @@ class ONNXModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: IAIOnnxRuntimeModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         from .models.base import IAIOnnxRuntimeModel

@@ -443,17 +441,17 @@ class ONNXModelPatcher:
             ti_tokenizer = copy.deepcopy(tokenizer)
             ti_manager = TextualInversionManager(ti_tokenizer)

-            def _get_trigger(ti, index):
-                trigger = ti.name
+            def _get_trigger(ti_name, index):
+                trigger = ti_name
                 if index > 0:
                     trigger += f"-!pad-{i}"
                 return f"<{trigger}>"

             # modify tokenizer
             new_tokens_added = 0
-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 for i in range(ti.embedding.shape[0]):
-                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

             # modify text_encoder
             orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@@ -463,11 +461,11 @@ class ONNXModelPatcher:
                 axis=0,
             )

-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 ti_tokens = []
                 for i in range(ti.embedding.shape[0]):
                     embedding = ti.embedding[i].detach().numpy()
-                    trigger = _get_trigger(ti, i)
+                    trigger = _get_trigger(ti_name, i)

                     token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                     if token_id == ti_tokenizer.unk_token_id:
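
Net effect of the diff: the trigger name used to patch the tokenizer now travels alongside the loaded model as a (name, model) tuple, keyed by the name the prompt actually used for lookup, instead of being read from a TextualInversionModel.name attribute populated from the checkpoint's file stem. A minimal sketch of the resulting trigger expansion, outside InvokeAI: FakeTI, embedding_rows, expand_triggers, and the "easynegative" name are all hypothetical stand-ins for illustration only.

    # Sketch (not InvokeAI code): how (name, ti) tuples drive trigger token
    # generation after this change. FakeTI stands in for the loaded
    # TextualInversionModel, which no longer carries a name attribute.
    from typing import Any, List, Tuple


    class FakeTI:
        def __init__(self, vectors: int):
            # Stand-in for ti.embedding of shape [n, 768]; only n matters here.
            self.embedding_rows = vectors


    def expand_triggers(ti_list: List[Tuple[str, Any]]) -> List[str]:
        """Emit one tokenizer token per embedding vector, padding the extra
        vectors with -!pad-N suffixes, keyed by the prompt-facing name."""
        tokens = []
        for ti_name, ti in ti_list:  # the name travels with the model
            for i in range(ti.embedding_rows):
                trigger = ti_name if i == 0 else f"{ti_name}-!pad-{i}"
                tokens.append(f"<{trigger}>")
        return tokens


    # The trigger now matches whatever name resolved the model in the prompt,
    # not the checkpoint's file stem:
    print(expand_triggers([("easynegative", FakeTI(3))]))
    # ['<easynegative>', '<easynegative-!pad-1>', '<easynegative-!pad-2>']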