sort embeddings by name (case insensitive)

This commit is contained in:
Brad Smith 2023-04-08 15:58:00 -04:00
parent 22bcc7be42
commit 27b9ec60e4
No known key found for this signature in database
GPG Key ID: CDABCFFBBD8DA710

View File

@ -2,7 +2,7 @@ import os
import sys import sys
import traceback import traceback
import inspect import inspect
from collections import namedtuple from collections import namedtuple, OrderedDict
import torch import torch
import tqdm import tqdm
@ -108,7 +108,7 @@ class DirWithTextualInversionEmbeddings:
class EmbeddingDatabase: class EmbeddingDatabase:
def __init__(self): def __init__(self):
self.ids_lookup = {} self.ids_lookup = {}
self.word_embeddings = {} self.word_embeddings = OrderedDict()
self.skipped_embeddings = {} self.skipped_embeddings = {}
self.expected_shape = -1 self.expected_shape = -1
self.embedding_dirs = {} self.embedding_dirs = {}
@ -233,6 +233,9 @@ class EmbeddingDatabase:
self.load_from_dir(embdir) self.load_from_dir(embdir)
embdir.update() embdir.update()
# re-sort word_embeddings because load_from_dir may not load in alphabetic order.
self.word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if self.previously_displayed_embeddings != displayed_embeddings: if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings self.previously_displayed_embeddings = displayed_embeddings