use modelloader for #4956

AUTOMATIC 2022-12-03 18:45:51 +03:00
parent 2a649154ec
commit 4b0dc206ed

modules/interrogate.py

@@ -1,4 +1,3 @@
-import contextlib
 import os
 import sys
 import traceback
@@ -11,12 +10,9 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths, lowvram
+from modules import devices, paths, lowvram, modelloader
 
 blip_image_eval_size = 384
-blip_local_dir = os.path.join('models', 'Interrogator')
-blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth')
-blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
 clip_model_name = 'ViT-L/14'
 
 Category = namedtuple("Category", ["name", "topn", "items"])
@@ -49,16 +45,14 @@ class InterrogateModels:
 
     def load_blip_model(self):
         import models.blip
 
-        if not os.path.isfile(blip_local_file):
-            if not os.path.isdir(blip_local_dir):
-                os.mkdir(blip_local_dir)
-
-            print("Downloading BLIP...")
-            from requests import get as reqget
-            open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content)
-            print("BLIP downloaded to", blip_local_file + '.')
-
-        blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+        files = modelloader.load_models(
+            model_path=os.path.join(paths.models_path, "BLIP"),
+            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+            ext_filter=[".pth"],
+            download_name='model_base_caption_capfilt_large.pth',
+        )
 
+        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
 
         blip_model.eval()
 
         return blip_model
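
For context on what the new call does: modelloader.load_models returns the paths of model files found under model_path (filtered by ext_filter), downloading model_url to download_name first when nothing is present, so files[0] above is the local BLIP checkpoint. A minimal sketch of that look-up-or-download pattern, with a hypothetical load_models_sketch standing in for the real helper, might look like:

import os
from urllib.request import urlretrieve

def load_models_sketch(model_path, model_url=None, ext_filter=None, download_name=None):
    # Hypothetical stand-in for modules.modelloader.load_models:
    # list matching files under model_path, downloading the default
    # checkpoint first if the directory has none.
    os.makedirs(model_path, exist_ok=True)

    def matching():
        return sorted(
            os.path.join(model_path, name)
            for name in os.listdir(model_path)
            if ext_filter is None or os.path.splitext(name)[1] in ext_filter
        )

    files = matching()
    if not files and model_url is not None and download_name is not None:
        # One-time download of the default checkpoint.
        urlretrieve(model_url, os.path.join(model_path, download_name))
        files = matching()
    return files

Under this sketch, the first call with the arguments from the diff would download the checkpoint into models/BLIP/ and return ['models/BLIP/model_base_caption_capfilt_large.pth']; later calls simply return whatever .pth files are already there, which also lets users drop in their own checkpoint instead of the hardcoded one.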