From 4635f31270d1b5d41ad63815cb400b1ca73ea859 Mon Sep 17 00:00:00 2001 From: klimaleksus Date: Mon, 29 May 2023 01:09:59 +0500 Subject: [PATCH] Refactor EmbeddingDatabase.register_embedding() to allow unregistering --- .../textual_inversion/textual_inversion.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d489ed1e..cbf94498 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -120,16 +120,29 @@ class EmbeddingDatabase: self.embedding_dirs.clear() def register_embedding(self, embedding, model): - self.word_embeddings[embedding.name] = embedding - - ids = model.cond_stage_model.tokenize([embedding.name])[0] + return self.register_embedding_by_name(embedding, model, embedding.name) + def register_embedding_by_name(self, embedding, model, name): + ids = model.cond_stage_model.tokenize([name])[0] first_id = ids[0] if first_id not in self.ids_lookup: self.ids_lookup[first_id] = [] - - self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) - + if name in self.word_embeddings: + # remove old one from the lookup list + lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name] + else: + lookup = self.ids_lookup[first_id] + if embedding is not None: + lookup += [(ids, embedding)] + self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True) + if embedding is None: + # unregister embedding with specified name + if name in self.word_embeddings: + del self.word_embeddings[name] + if len(self.ids_lookup[first_id])==0: + del self.ids_lookup[first_id] + return None + self.word_embeddings[name] = embedding return embedding def get_expected_shape(self):