diff --git a/modules/sd_models.py b/modules/sd_models.py index 4bd8783e..8e42bfea 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -409,12 +409,16 @@ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_w class SdModelData: def __init__(self): self.sd_model = None + self.was_loaded_at_least_once = False self.lock = threading.Lock() def get_sd_model(self): + if self.was_loaded_at_least_once: + return self.sd_model + if self.sd_model is None: with self.lock: - if self.sd_model is not None: + if self.sd_model is not None or self.was_loaded_at_least_once: return self.sd_model try: @@ -495,6 +499,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model.eval() model_data.sd_model = sd_model + model_data.was_loaded_at_least_once = True sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model