Add api method to get LoRA models with prompt

This commit is contained in:
Sayo 2023-05-08 20:38:10 +08:00
parent 34a82a345a
commit f9abe4cddc
2 changed files with 34 additions and 10 deletions

View File

@ -2,9 +2,8 @@ import glob
import os
import re
import torch
from typing import Union, List, Optional
from fastapi import FastAPI
import gradio as gr
from typing import Union
import scripts.api as api
from modules import shared, devices, sd_models, errors, scripts
@ -445,12 +444,6 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
def api(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def getloras():
return [{"name": name, "path": available_loras[name].filename, "prompt": ""} for name in available_loras]
available_loras = {}
available_lora_aliases = {}
loaded_loras = []
@ -458,6 +451,6 @@ loaded_loras = []
list_available_loras()
try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(api)
script_callbacks.on_app_started(api.api)
except:
pass

View File

@ -0,0 +1,31 @@
from fastapi import FastAPI
import gradio as gr
import json
import os
import lora
def get_lora_prompts(path):
directory, filename = os.path.split(path)
name_without_ext = os.path.splitext(filename)[0]
new_filename = name_without_ext + '.civitai.info'
try:
new_path = os.path.join(directory, new_filename)
if os.path.exists(new_path):
with open(new_path, 'r') as f:
data = json.load(f)
trained_words = data.get('trainedWords', [])
if len(trained_words) > 0:
result = ','.join(trained_words)
return result
else:
return ''
else:
return ''
except Exception as e:
return ''
def api(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [{"name": name, "path": lora.available_loras[name].filename, "prompt": get_lora_prompts(lora.available_loras[name].filename)} for name in lora.available_loras]