Merge pull request #4956 from TiagoSantos81/offline_BLIP

[CLIP interrogator] use local file, if available
This commit is contained in:
AUTOMATIC1111 2022-12-03 18:17:56 +03:00 committed by GitHub
commit 2a649154ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,6 +14,8 @@ import modules.shared as shared
from modules import devices, paths, lowvram from modules import devices, paths, lowvram
blip_image_eval_size = 384 blip_image_eval_size = 384
blip_local_dir = os.path.join('models', 'Interrogator')
blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth')
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
@ -47,7 +49,16 @@ class InterrogateModels:
def load_blip_model(self): def load_blip_model(self):
import models.blip import models.blip
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) if not os.path.isfile(blip_local_file):
if not os.path.isdir(blip_local_dir):
os.mkdir(blip_local_dir)
print("Downloading BLIP...")
from requests import get as reqget
open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content)
print("BLIP downloaded to", blip_local_file + '.')
blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval() blip_model.eval()
return blip_model return blip_model