Start restore parameters

This commit is contained in:
space-nuko
2023-05-29 00:16:02 -05:00
parent fde480cb43
commit 26ab7989c8
10 changed files with 485 additions and 110 deletions

View File

@@ -1,31 +1,88 @@
import type { ComfyBoxStdGroupLoRA, ComfyBoxStdPrompt } from "$lib/ComfyBoxStdPrompt";
import type { SerializedPrompt, SerializedPromptInputs } from "./components/ComfyApp";
import StdPrompt from "$lib/ComfyBoxStdPrompt";
import type { SafeParseReturnType, ZodError } from "zod";
import type { ComfyNodeID } from "./api";
import type { SerializedAppState, SerializedPrompt, SerializedPromptInputs, SerializedPromptInputsAll } from "./components/ComfyApp";
import { ComfyComboNode, type ComfyWidgetNode } from "./nodes/widgets";
import { basename, isSerializedPromptInputLink } from "./utils";
export type ComfyPromptConverter = (stdPrompt: ComfyBoxStdPrompt, inputs: SerializedPromptInputs, nodeID: ComfyNodeID) => void;
function LoraLoader(stdPrompt: ComfyBoxStdPrompt, inputs: SerializedPromptInputs) {
const params = stdPrompt.parameters
const lora: ComfyBoxStdGroupLoRA = {
model_name: inputs["lora_name"],
strength_unet: inputs["strength_model"],
strength_tenc: inputs["strength_clip"]
}
if (params.lora)
params.lora.push(lora)
else
params.lora = [lora]
export type ComfyPromptConverter = {
encoder: ComfyPromptEncoder,
decoder: ComfyPromptDecoder
}
const ALL_CONVERTERS: Record<string, ComfyPromptConverter> = {
LoraLoader
//
export type ComfyDecodeArgument = {
groupName: string,
keyName: string,
value: any,
widgetNode: ComfyWidgetNode
};
export type ComfyPromptEncoder = (stdPrompt: ComfyBoxStdPrompt, inputs: SerializedPromptInputs, nodeID: ComfyNodeID) => void;
export type ComfyPromptDecoder = (args: ComfyDecodeArgument[]) => void;
const LoraLoader: ComfyPromptConverter = {
encoder: (stdPrompt: ComfyBoxStdPrompt, inputs: SerializedPromptInputs) => {
const params = stdPrompt.parameters
const loras: ComfyBoxStdGroupLoRA[] = params.lora
for (const lora of loras) {
lora.model_hashes = {
addnet_shorthash: null // TODO find hashes for model!
}
}
},
decoder: (args: ComfyDecodeArgument[]) => {
// Find corresponding model names in the ComfyUI models folder from the model base filename
for (const arg of args) {
if (arg.groupName === "lora" && arg.keyName === "model_name" && arg.widgetNode.is(ComfyComboNode)) {
const modelBasename = basename(arg.value);
const found = arg.widgetNode.properties.values.find(k => k.indexOf(modelBasename) !== -1)
if (found)
arg.value = found;
}
}
}
}
// input name -> group/key in standard prompt
type ComfyStdPromptMapping = Record<string, string>
type ComfyStdPromptSpec = {
paramMapping: ComfyStdPromptMapping,
extraParams?: Record<string, string>,
converter?: ComfyPromptConverter,
}
const ALL_SPECS: Record<string, ComfyStdPromptSpec> = {
"KSampler": {
paramMapping: {
cfg: "k_sampler.cfg_scale",
seed: "k_sampler.seed",
steps: "k_sampler.steps",
sampler_name: "k_sampler.sampler_name",
scheduler: "k_sampler.scheduler",
denoise: "k_sampler.denoise",
},
},
"LoraLoader": {
paramMapping: {
lora_name: "lora.model_name",
strength_model: "lora.strength_unet",
strength_clip: "lora.strength_tenc",
},
extraParams: {
"lora.module_name": "LoRA",
},
converter: LoraLoader,
}
}
const COMMIT_HASH: string = __GIT_COMMIT_HASH__;
export default class ComfyBoxStdPromptSerializer {
serialize(prompt: SerializedPrompt): ComfyBoxStdPrompt {
serialize(prompt: SerializedPromptInputsAll, workflow?: SerializedAppState): [SafeParseReturnType<any, ComfyBoxStdPrompt>, any] {
const stdPrompt: ComfyBoxStdPrompt = {
version: 1,
metadata: {
@@ -33,23 +90,57 @@ export default class ComfyBoxStdPromptSerializer {
commit_hash: COMMIT_HASH,
extra_data: {
comfybox: {
workflows: [] // TODO!!!
}
}
},
parameters: {}
}
for (const [nodeID, inputs] of Object.entries(prompt.output)) {
for (const [nodeID, inputs] of Object.entries(prompt)) {
const classType = inputs.class_type
const converter = ALL_CONVERTERS[classType]
if (converter) {
converter(stdPrompt, inputs.inputs, nodeID)
const spec = ALL_SPECS[classType]
if (spec) {
console.warn("SPEC", spec, inputs)
let targets = {}
for (const [comfyKey, stdPromptKey] of Object.entries(spec.paramMapping)) {
const inputValue = inputs.inputs[comfyKey];
if (inputValue != null && !isSerializedPromptInputLink(inputValue)) {
console.warn("GET", comfyKey, inputValue)
const trail = stdPromptKey.split(".");
let target = null;
console.warn(trail, trail.length - 2);
for (let index = 0; index < trail.length - 1; index++) {
const name = trail[index];
if (index === 0) {
targets[name] ||= {}
target = targets[name]
}
else {
target = target[name]
}
console.warn(index, name, target)
}
let name = trail[trail.length - 1]
target[name] = inputValue
console.warn(stdPrompt.parameters)
}
}
// TODO converter.encode
for (const [groupName, group] of Object.entries(targets)) {
stdPrompt.parameters[groupName] ||= []
stdPrompt.parameters[groupName].push(group)
}
}
else {
console.warn("No StdPrompt type converter for comfy class!", classType)
console.warn("No StdPrompt type spec for comfy class!", classType)
}
}
return stdPrompt
return [StdPrompt.safeParse(stdPrompt), stdPrompt];
}
}