From 34c18dea90c169b14bad8c320c828b794ca8982a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 11 May 2023 23:11:07 -0500 Subject: [PATCH] Hook up more of the API --- src/lib/ComfyGraph.ts | 4 - src/lib/ComfyGraphCanvas.ts | 2 +- src/lib/api.ts | 232 ++++++++++++---------- src/lib/components/ComfyApp.ts | 81 +++++--- src/lib/components/ComfyQueue.svelte | 28 +-- src/lib/components/WidgetContainer.svelte | 4 +- src/lib/nodes/ComfyBackendNode.ts | 4 +- src/lib/nodes/ComfyGraphNode.ts | 4 +- src/lib/stores/queueState.ts | 181 +++++++++++++++-- src/lib/stores/uiState.ts | 23 ++- src/mobile/GenToolbar.svelte | 4 +- 11 files changed, 388 insertions(+), 179 deletions(-) diff --git a/src/lib/ComfyGraph.ts b/src/lib/ComfyGraph.ts index 46c1211..b0eaa55 100644 --- a/src/lib/ComfyGraph.ts +++ b/src/lib/ComfyGraph.ts @@ -24,10 +24,6 @@ type ComfyGraphEvents = { export default class ComfyGraph extends LGraph { eventBus: TypedEmitter = new EventEmitter() as TypedEmitter; - constructor() { - super(); - } - override onConfigure() { console.debug("Configured"); this.eventBus.emit("configured", this); diff --git a/src/lib/ComfyGraphCanvas.ts b/src/lib/ComfyGraphCanvas.ts index ae86b91..72ec18f 100644 --- a/src/lib/ComfyGraphCanvas.ts +++ b/src/lib/ComfyGraphCanvas.ts @@ -60,7 +60,7 @@ export default class ComfyGraphCanvas extends LGraphCanvas { let state = get(queueState); let color = null; - if (node.id === +state.runningNodeId) { + if (node.id === +state.runningNodeID) { color = "#0f0"; // this.app can be null inside the constructor if rendering is taking place already } else if (this.app && this.app.dragOverNode && node.id === this.app.dragOverNode.id) { diff --git a/src/lib/api.ts b/src/lib/api.ts index fa7fb35..d668fe8 100644 --- a/src/lib/api.ts +++ b/src/lib/api.ts @@ -1,3 +1,8 @@ +import type { Progress, SerializedPrompt, SerializedPromptOutput, SerializedPromptOutputs } from "./components/ComfyApp"; +import type TypedEmitter from "typed-emitter"; +import EventEmitter from "events"; +import type { GalleryOutput } from "./nodes/ComfyWidgetNodes"; + type PromptRequestBody = { client_id: string, prompt: any, @@ -8,27 +13,68 @@ type PromptRequestBody = { export type QueueItemType = "queue" | "history"; -export type ComfyAPIQueueStatus = { - exec_info: { - queue_remaining: number | "X"; - } +export type ComfyAPIStatusExecInfo = { + queueRemaining: number | "X"; } -export default class ComfyAPI extends EventTarget { - private registered: Set = new Set(); +export type ComfyAPIStatusResponse = { + execInfo?: ComfyAPIStatusExecInfo, + error?: string +} + +export type ComfyAPIQueueResponse = { + running: ComfyAPIHistoryItem[], + pending: ComfyAPIHistoryItem[], + error?: string +} + +export type NodeID = string; +export type PromptID = string; // UUID + +export type ComfyAPIHistoryItem = [ + number, // prompt number + PromptID, + SerializedPrompt, + any, // extra data + NodeID[] // good outputs +] + +export type ComfyAPIPromptResponse = { + promptID?: PromptID, + error?: string +} + +export type ComfyAPIHistoryEntry = { + prompt: ComfyAPIHistoryItem, + outputs: SerializedPromptOutputs +} + +export type ComfyAPIHistoryResponse = { + history: Record, + error?: string +} + +type ComfyAPIEvents = { + status: (status: ComfyAPIStatusResponse | null, error?: Error | null) => void, + progress: (progress: Progress) => void, + reconnecting: () => void, + reconnected: () => void, + executing: (promptID: PromptID | null, runningNodeID: NodeID | null) => void, + executed: (promptID: PromptID, nodeID: NodeID, output: SerializedPromptOutput) => void, + execution_cached: (promptID: PromptID, nodes: NodeID[]) => void, + execution_error: (promptID: PromptID, message: string) => void, +} + +export default class ComfyAPI { + private eventBus: TypedEmitter = new EventEmitter() as TypedEmitter; socket: WebSocket | null = null; clientId: string | null = null; hostname: string | null = null; port: number | null = 8188; - constructor() { - super(); - } - - override addEventListener(type: string, callback: EventListenerOrEventListenerObject | null, options?: AddEventListenerOptions | boolean) { - super.addEventListener(type, callback, options); - this.registered.add(type); + addEventListener(type: E, callback: ComfyAPIEvents[E]) { + this.eventBus.addListener(type, callback); } /** @@ -39,9 +85,9 @@ export default class ComfyAPI extends EventTarget { try { const resp = await fetch(this.getBackendUrl() + "/prompt"); const status = await resp.json(); - this.dispatchEvent(new CustomEvent("status", { detail: status })); + this.eventBus.emit("status", { execInfo: { queueRemaining: status.exec_info.queue_remaining } }); } catch (error) { - this.dispatchEvent(new CustomEvent("status", { detail: null })); + this.eventBus.emit("status", { error: error.toString() }); } }, 1000); } @@ -77,7 +123,7 @@ export default class ComfyAPI extends EventTarget { this.socket.addEventListener("open", () => { opened = true; if (isReconnect) { - this.dispatchEvent(new CustomEvent("reconnected")); + this.eventBus.emit("reconnected"); } }); @@ -94,8 +140,8 @@ export default class ComfyAPI extends EventTarget { this.createSocket(true); }, 300); if (opened) { - this.dispatchEvent(new CustomEvent("status", { detail: null })); - this.dispatchEvent(new CustomEvent("reconnecting")); + this.eventBus.emit("status", null); + this.eventBus.emit("reconnecting"); } }); @@ -108,29 +154,25 @@ export default class ComfyAPI extends EventTarget { this.clientId = msg.data.sid; sessionStorage["Comfy.SessionId"] = this.clientId; } - this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); + this.eventBus.emit("status", msg.data.status); break; case "progress": - this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); + this.eventBus.emit("progress", msg.data as Progress); break; case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data })); + this.eventBus.emit("executing", msg.data.prompt_id, msg.data.node); break; case "executed": - this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); + this.eventBus.emit("executed", msg.data.prompt_id, msg.data.node, msg.data.output); break; case "execution_cached": - this.dispatchEvent(new CustomEvent("execution_cached", { detail: msg.data })); + this.eventBus.emit("execution_cached", msg.data.prompt_id, msg.data.nodes); break; case "execution_error": - this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + this.eventBus.emit("execution_error", msg.data.prompt_id, msg.data.message); break; default: - if (this.registered.has(msg.type)) { - this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); - } else { - throw new Error("Unknown message type"); - } + throw new Error(`Unknown message type: ${msg.type} ${msg}`); } } catch (error) { console.warn("Unhandled message:", event.data); @@ -149,27 +191,27 @@ export default class ComfyAPI extends EventTarget { * Gets a list of extension urls * @returns An array of script urls to import */ - async getExtensions() { - const resp = await fetch(this.getBackendUrl() + `/extensions`, { cache: "no-store" }); - return await resp.json(); + async getExtensions(): Promise { + return fetch(this.getBackendUrl() + `/extensions`, { cache: "no-store" }) + .then(resp => resp.json()) } /** * Gets a list of embedding names * @returns An array of script urls to import */ - async getEmbeddings() { - const resp = await fetch(this.getBackendUrl() + "/embeddings", { cache: "no-store" }); - return await resp.json(); + async getEmbeddings(): Promise { + return fetch(this.getBackendUrl() + "/embeddings", { cache: "no-store" }) + .then(resp => resp.json()) } /** * Loads node object definitions for the graph * @returns The node definitions */ - async getNodeDefs() { - const resp = await fetch(this.getBackendUrl() + "/object_info", { cache: "no-store" }); - return await resp.json(); + async getNodeDefs(): Promise { + return fetch(this.getBackendUrl() + "/object_info", { cache: "no-store" }) + .then(resp => resp.json()) } /** @@ -177,11 +219,11 @@ export default class ComfyAPI extends EventTarget { * @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue * @param {object} prompt The prompt data to queue */ - async queuePrompt(number: number, { output, workflow }) { + async queuePrompt(number: number, { output, workflow }, extra_data: any): Promise { const body: PromptRequestBody = { client_id: this.clientId, prompt: output, - extra_data: { extra_pnginfo: { workflow } }, + extra_data, front: false, number: number }; @@ -192,67 +234,52 @@ export default class ComfyAPI extends EventTarget { body.number = number; } - const res = await fetch(this.getBackendUrl() + "/prompt", { + let postBody = null; + try { + postBody = JSON.stringify(body) + } + catch (error) { + return Promise.reject({ error }) + } + + return fetch(this.getBackendUrl() + "/prompt", { method: "POST", headers: { "Content-Type": "application/json", }, - body: JSON.stringify(body), - }); - - if (res.status !== 200) { - throw { - response: await res.text(), - }; - } - } - - /** - * Loads a list of items (queue or history) - * @param {string} type The type of items to load, queue or history - * @returns The items of the specified type grouped by their status - */ - async getItems(type: QueueItemType) { - if (type === "queue") { - return this.getQueue(); - } - return this.getHistory(); + body: postBody + }) + .then(res => res.json()) + .then(raw => { return { promptID: raw.prompt_id } }) + .catch(res => { throw res.text() }) + .catch(error => { return { error } }) } /** * Gets the current state of the queue * @returns The currently running and queued items */ - async getQueue() { - try { - const res = await fetch(this.getBackendUrl() + "/queue"); - const data = await res.json(); - return { - // Running action uses a different endpoint for cancelling - Running: data.queue_running.map((prompt) => ({ - prompt, - remove: { name: "Cancel", cb: () => this.interrupt() }, - })), - Pending: data.queue_pending.map((prompt) => ({ prompt })), - }; - } catch (error) { - console.error(error); - return { Running: [], Pending: [], error }; - } + async getQueue(): Promise { + return fetch(this.getBackendUrl() + "/queue") + .then(res => res.json()) + .then(data => { + return { + running: data.queue_running, + pending: data.queue_pending, + } + }) + .catch(error => { return { running: [], pending: [], error } }) } /** * Gets the prompt execution history * @returns Prompt history including node outputs */ - async getHistory() { - try { - const res = await fetch(this.getBackendUrl() + "/history"); - return { History: Object.values(await res.json()) }; - } catch (error) { - console.error(error); - return { History: [], error }; - } + async getHistory(): Promise { + return fetch(this.getBackendUrl() + "/history") + .then(res => res.json()) + .then(history => { return { history } }) + .catch(error => { return { history: {}, error } }) } /** @@ -260,18 +287,21 @@ export default class ComfyAPI extends EventTarget { * @param {*} type The endpoint to post to * @param {*} body Optional POST data */ - private async postItem(type: string, body: any) { + private async postItem(type: QueueItemType, body: any): Promise { try { - await fetch("/" + type, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: body ? JSON.stringify(body) : undefined, - }); - } catch (error) { - console.error(error); + body = body ? JSON.stringify(body) : body } + catch (error) { + return Promise.reject(error) + } + + return fetch("/" + type, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: body + }); } /** @@ -279,22 +309,22 @@ export default class ComfyAPI extends EventTarget { * @param {string} type The type of item to delete, queue or history * @param {number} id The id of the item to delete */ - async deleteItem(type: string, id: number) { - await this.postItem(type, { delete: [id] }); + async deleteItem(type: QueueItemType, id: number): Promise { + return this.postItem(type, { delete: [id] }); } /** * Clears the specified list * @param {string} type The type of list to clear, queue or history */ - async clearItems(type: string) { - await this.postItem(type, { clear: true }); + async clearItems(type: QueueItemType): Promise { + return this.postItem(type, { clear: true }); } /** * Interrupts the execution of the running prompt */ - async interrupt() { - await this.postItem("interrupt", null); + async interrupt(): Promise { + return fetch("/interrupt", { method: "POST" }); } } diff --git a/src/lib/components/ComfyApp.ts b/src/lib/components/ComfyApp.ts index c42f227..087b98c 100644 --- a/src/lib/components/ComfyApp.ts +++ b/src/lib/components/ComfyApp.ts @@ -1,6 +1,6 @@ import { LiteGraph, LGraph, LGraphCanvas, LGraphNode, type LGraphNodeConstructor, type LGraphNodeExecutable, type SerializedLGraph, type SerializedLGraphGroup, type SerializedLGraphNode, type SerializedLLink, NodeMode, type Vector2, BuiltInSlotType, type INodeInputSlot } from "@litegraph-ts/core"; import type { LConnectionKind, INodeSlot } from "@litegraph-ts/core"; -import ComfyAPI, { type ComfyAPIQueueStatus } from "$lib/api" +import ComfyAPI, { type ComfyAPIStatusResponse, type NodeID, type PromptID } from "$lib/api" import { getPngMetadata, importA1111 } from "$lib/pnginfo"; import EventEmitter from "events"; import type TypedEmitter from "typed-emitter"; @@ -32,6 +32,7 @@ import { download, jsonToJsObject, promptToGraphVis, range, workflowToGraphVis } import notify from "$lib/notify"; import configState from "$lib/stores/configState"; import { blankGraph } from "$lib/defaultGraph"; +import type { GalleryOutput } from "$lib/nodes/ComfyWidgetNodes"; export const COMFYBOX_SERIAL_VERSION = 1; @@ -55,20 +56,22 @@ export type SerializedAppState = { } /** [link origin, link index] | value */ -export type SerializedPromptInput = [string, number] | any +export type SerializedPromptInput = [NodeID, number] | any export type SerializedPromptInputs = { - inputs: Record, + inputs: Record, class_type: string } -export type SerializedPromptOutput = Record +export type SerializedPromptInputsAll = Record export type SerializedPrompt = { workflow: SerializedLGraph, - output: SerializedPromptOutput + output: SerializedPromptInputsAll } +export type SerializedPromptOutputs = Record + export type Progress = { value: number, max: number @@ -176,6 +179,8 @@ export default class ComfyApp { this.addPasteHandler(); this.addKeyboardHandler(); + await this.updateHistoryAndQueue(); + // await this.#invokeExtensionsAsync("setup"); // Ensure the canvas fills the window @@ -319,47 +324,48 @@ export default class ComfyApp { * Handles updates from the API socket */ private addApiUpdateHandlers() { - this.api.addEventListener("status", ({ detail: ComfyAPIStatus }: CustomEvent) => { - // this.ui.setStatus(detail); + this.api.addEventListener("status", (status: ComfyAPIStatusResponse) => { + queueState.statusUpdated(status); }); this.api.addEventListener("reconnecting", () => { - // this.ui.dialog.show("Reconnecting..."); + uiState.reconnecting() }); this.api.addEventListener("reconnected", () => { - // this.ui.dialog.close(); + uiState.reconnected() }); - this.api.addEventListener("progress", ({ detail }: CustomEvent) => { - queueState.progressUpdated(detail); + this.api.addEventListener("progress", (progress: Progress) => { + queueState.progressUpdated(progress); this.lGraph.setDirtyCanvas(true, false); }); - this.api.addEventListener("executing", ({ detail }: CustomEvent) => { - queueState.executingUpdated(detail.node); + this.api.addEventListener("executing", (promptID: PromptID | null, nodeID: NodeID | null) => { + queueState.executingUpdated(promptID, nodeID); this.lGraph.setDirtyCanvas(true, false); }); - this.api.addEventListener("status", (ev: CustomEvent) => { - queueState.statusUpdated(ev.detail as ComfyAPIQueueStatus); + this.api.addEventListener("status", (status: ComfyAPIStatusResponse | null) => { + queueState.statusUpdated(status); }); - this.api.addEventListener("executed", ({ detail }: CustomEvent) => { - this.nodeOutputs[detail.node] = detail.output; - const node = this.lGraph.getNodeById(detail.node) as ComfyGraphNode; + this.api.addEventListener("executed", (promptID: PromptID, nodeID: NodeID, output: GalleryOutput) => { + this.nodeOutputs[nodeID] = output; + const node = this.lGraph.getNodeById(parseInt(nodeID)) as ComfyGraphNode; if (node?.onExecuted) { - node.onExecuted(detail.output); + node.onExecuted(output); } + queueState.onExecuted(promptID, nodeID, output) }); - this.api.addEventListener("execution_cached", ({ detail }: CustomEvent) => { - // TODO detail.nodes + this.api.addEventListener("execution_cached", (promptID: PromptID, nodes: NodeID[]) => { + queueState.executionCached(promptID, nodes) }); - this.api.addEventListener("execution_error", ({ detail }: CustomEvent) => { - queueState.update(s => { s.progress = null; s.runningNodeId = null; return s; }) - notify(`Execution error: ${detail.message}`, { type: "error", timeout: 10000 }) + this.api.addEventListener("execution_error", (promptID: PromptID, message: string) => { + queueState.executionError(promptID, message) + notify(`Execution error: ${message}`, { type: "error", timeout: 10000 }) }); this.api.init(); @@ -379,6 +385,13 @@ export default class ComfyApp { }); } + private async updateHistoryAndQueue() { + const queue = await this.api.getQueue(); + const history = await this.api.getHistory(); + console.warn("QUEUE", queue) + console.warn("HISTORY", history) + } + private requestPermissions() { if (Notification.permission === "default") { Notification.requestPermission() @@ -443,6 +456,8 @@ export default class ComfyApp { this.lGraph.start(); this.lGraph.eventBus.on("afterExecute", () => this.lCanvas.draw(true)) + + uiState.update(s => { s.uiUnlocked = this.lGraph._nodes.length === 0; return s; }) } async initDefaultGraph() { @@ -729,10 +744,20 @@ export default class ComfyApp { const p = await this.graphToPrompt(tag); console.debug(promptToGraphVis(p)) + const extra_data = { extra_pnginfo: { workflow: p.workflow } } + + let error = null; + let promptID = null; + try { - await this.api.queuePrompt(num, p); + const response = await this.api.queuePrompt(num, p, extra_data); + promptID = response.promptID; + error = response.error; } catch (error) { - // this.ui.dialog.show(error.response || error.toString()); + error = error.toString(); + } + + if (error != null) { const mes = error.response || error.toString() notify(`Error queuing prompt:\n${mes}`, { type: "error" }) console.error(promptToGraphVis(p)) @@ -748,7 +773,7 @@ export default class ComfyApp { } this.lCanvas.draw(true, true); - // await this.ui.queue.update(); + queueState.afterQueued(promptID, num, p, extra_data) } } } finally { @@ -767,7 +792,7 @@ export default class ComfyApp { if (pngInfo.comfyBoxConfig) { this.deserialize(JSON.parse(pngInfo.comfyBoxConfig)); } else if (pngInfo.parameters) { - throw "TODO import A111 import!" + throw "TODO A111 import!" // importA1111(this.lGraph, pngInfo.parameters, this.api); } else { diff --git a/src/lib/components/ComfyQueue.svelte b/src/lib/components/ComfyQueue.svelte index 181bf98..15b5d71 100644 --- a/src/lib/components/ComfyQueue.svelte +++ b/src/lib/components/ComfyQueue.svelte @@ -49,18 +49,18 @@ $: if (entries) { _entries = [] - // for (const entry of entries) { - // for (const outputs of Object.values(entry.outputs)) { - // const allImages = outputs.images.map(r => { - // // TODO configure backend URL - // const url = "http://localhost:8188/view?" - // const params = new URLSearchParams(r) - // return url + params - // }); - // - // _entries.push({ allImages, name: "Output" }) - // } - // } + for (const entry of entries) { + for (const outputs of Object.values(entry.outputs)) { + const allImages = outputs.images.map(r => { + // TODO configure backend URL + const url = "http://localhost:8188/view?" + const params = new URLSearchParams(r) + return url + params + }); + + _entries.push({ allImages, name: "Output" }) + } + } } @@ -76,9 +76,9 @@ {/each}
- {#if $queueState.runningNodeId || $queueState.progress} + {#if $queueState.runningNodeID || $queueState.progress}
- Node: {getNodeInfo($queueState.runningNodeId)} + Node: {getNodeInfo($queueState.runningNodeID)}
diff --git a/src/lib/components/WidgetContainer.svelte b/src/lib/components/WidgetContainer.svelte index 58d784e..cc13603 100644 --- a/src/lib/components/WidgetContainer.svelte +++ b/src/lib/components/WidgetContainer.svelte @@ -50,7 +50,7 @@ $: if ($queueState && widget && widget.node) { - dragItem.isNodeExecuting = $queueState.runningNodeId === widget.node.id; + dragItem.isNodeExecuting = $queueState.runningNodeID === widget.node.id; } function getWidgetClass() { @@ -72,7 +72,7 @@
diff --git a/src/lib/nodes/ComfyBackendNode.ts b/src/lib/nodes/ComfyBackendNode.ts index 4c1e3c3..c6039e6 100644 --- a/src/lib/nodes/ComfyBackendNode.ts +++ b/src/lib/nodes/ComfyBackendNode.ts @@ -1,7 +1,7 @@ import LGraphCanvas from "@litegraph-ts/core/src/LGraphCanvas"; import ComfyGraphNode from "./ComfyGraphNode"; import ComfyWidgets from "$lib/widgets" -import type { ComfyWidgetNode } from "./ComfyWidgetNodes"; +import type { ComfyWidgetNode, GalleryOutput } from "./ComfyWidgetNodes"; import { BuiltInSlotType, type SerializedLGraphNode } from "@litegraph-ts/core"; import type IComfyInputSlot from "$lib/IComfyInputSlot"; import type { ComfyInputConfig } from "$lib/IComfyInputSlot"; @@ -110,7 +110,7 @@ export class ComfyBackendNode extends ComfyGraphNode { } } - override onExecuted(outputData: any) { + override onExecuted(outputData: GalleryOutput) { console.warn("onExecuted outputs", outputData) this.triggerSlot(0, outputData) } diff --git a/src/lib/nodes/ComfyGraphNode.ts b/src/lib/nodes/ComfyGraphNode.ts index 9d8924b..1250735 100644 --- a/src/lib/nodes/ComfyGraphNode.ts +++ b/src/lib/nodes/ComfyGraphNode.ts @@ -3,7 +3,7 @@ import type { SerializedPrompt } from "$lib/components/ComfyApp"; import type ComfyWidget from "$lib/components/widgets/ComfyWidget"; import { LGraph, LGraphNode, LLink, LiteGraph, NodeMode, type INodeInputSlot, type SerializedLGraphNode, type Vector2, type INodeOutputSlot, LConnectionKind, type SlotType, LGraphCanvas, getStaticPropertyOnInstance, type PropertyLayout, type SlotLayout } from "@litegraph-ts/core"; import type { SvelteComponentDev } from "svelte/internal"; -import type { ComfyWidgetNode } from "./ComfyWidgetNodes"; +import type { ComfyWidgetNode, GalleryOutput } from "./ComfyWidgetNodes"; import type IComfyInputSlot from "$lib/IComfyInputSlot"; import uiState from "$lib/stores/uiState"; import { get } from "svelte/store"; @@ -48,7 +48,7 @@ export default class ComfyGraphNode extends LGraphNode { * Triggered when the backend sends a finished output back with this node's ID. * Valid for output nodes like SaveImage and PreviewImage. */ - onExecuted?(output: any): void; + onExecuted?(output: GalleryOutput): void; /* * Allows you to manually specify an auto-config for certain input slot diff --git a/src/lib/stores/queueState.ts b/src/lib/stores/queueState.ts index 8f3076f..8974708 100644 --- a/src/lib/stores/queueState.ts +++ b/src/lib/stores/queueState.ts @@ -1,5 +1,6 @@ -import type { ComfyAPIQueueStatus } from "$lib/api"; -import type { Progress } from "$lib/components/ComfyApp"; +import type { ComfyAPIHistoryItem, ComfyAPIQueueResponse, ComfyAPIStatusResponse, NodeID, PromptID } from "$lib/api"; +import type { Progress, SerializedPrompt, SerializedPromptOutputs } from "$lib/components/ComfyApp"; +import type { GalleryOutput } from "$lib/nodes/ComfyWidgetNodes"; import { writable, type Writable } from "svelte/store"; export type QueueItem = { @@ -7,48 +8,188 @@ export type QueueItem = { } type QueueStateOps = { - statusUpdated: (status: ComfyAPIQueueStatus | null) => void, - executingUpdated: (runningNodeId: string | null) => void, - progressUpdated: (progress: Progress | null) => void + queueUpdated: (queue: ComfyAPIQueueResponse) => void, + statusUpdated: (status: ComfyAPIStatusResponse | null) => void, + executingUpdated: (promptID: PromptID | null, runningNodeID: NodeID | null) => void, + executionCached: (promptID: PromptID, nodes: NodeID[]) => void, + executionError: (promptID: PromptID, message: string) => void, + progressUpdated: (progress: Progress) => void + afterQueued: (promptID: PromptID, number: number, prompt: SerializedPrompt, extraData: any) => void + onExecuted: (promptID: PromptID, nodeID: NodeID, output: GalleryOutput) => void +} + +export type QueueEntry = { + number: number, + promptID: PromptID, + prompt: SerializedPrompt, + extraData: any, + goodOutputs: NodeID[], + + // Collected while the prompt is still executing + outputs: SerializedPromptOutputs, +} + +export type CompletedQueueEntry = { + entry: QueueEntry, + type: "success" | "error" | "all_cached", + error?: string, } export type QueueState = { + queueRunning: QueueEntry[], + queuePending: QueueEntry[], + queueCompleted: CompletedQueueEntry[], queueRemaining: number | "X" | null; - runningNodeId: number | null; + runningNodeID: number | null; progress: Progress | null } type WritableQueueStateStore = Writable & QueueStateOps; -const store: Writable = writable({ queueRemaining: null, runningNodeId: null, progress: null }) +const store: Writable = writable({ + queueRunning: [], + queuePending: [], + queueCompleted: [], + queueRemaining: null, + runningNodeID: null, + progress: null +}) -function statusUpdated(status: ComfyAPIQueueStatus | null) { +function toQueueEntry(resp: ComfyAPIHistoryItem): QueueEntry { + const [num, promptID, prompt, extraData, goodOutputs] = resp + return { + number: num, + promptID, + prompt, + extraData, + goodOutputs, + outputs: {} + } +} + +function queueUpdated(queue: ComfyAPIQueueResponse) { store.update((s) => { - if (status !== null) - s.queueRemaining = status.exec_info.queue_remaining; + s.queueRunning = queue.running.map(toQueueEntry); + s.queuePending = queue.pending.map(toQueueEntry); + s.queueRemaining = s.queuePending.length; return s }) } -function executingUpdated(runningNodeId: string | null) { - store.update((s) => { - s.progress = null; - s.runningNodeId = parseInt(runningNodeId); - return s - }) -} - -function progressUpdated(progress: Progress | null) { +function progressUpdated(progress: Progress) { store.update((s) => { s.progress = progress; return s }) } +function statusUpdated(status: ComfyAPIStatusResponse | null) { + store.update((s) => { + if (status !== null) + s.queueRemaining = status.execInfo.queueRemaining; + return s + }) +} + +function executingUpdated(promptID: PromptID | null, runningNodeID: NodeID | null) { + console.debug("[queueState] executingUpdated", promptID, runningNodeID) + store.update((s) => { + s.progress = null; + if (runningNodeID != null) { + s.runningNodeID = parseInt(runningNodeID); + } + else if (promptID != null) { + // Prompt finished executing. + const index = s.queuePending.findIndex(e => e.promptID === promptID) + if (index) { + s.queuePending = s.queuePending.splice(index, 1); + } + s.progress = null; + s.runningNodeID = null; + } + return s + }) +} + +function executionCached(promptID: PromptID, nodes: NodeID[]) { + console.debug("[queueState] executionCached", promptID, nodes) + store.update(s => { + const index = s.queuePending.findIndex(e => e.promptID === promptID) + if (index) { + const entry = s.queuePending[index] + + if (nodes.length >= Object.keys(entry.prompt.output).length) { + s.queuePending = s.queuePending.splice(index, 1); + const completed: CompletedQueueEntry = { + entry, + type: "all_cached" + } + s.queueCompleted.push(completed) + } + } + s.progress = null; + s.runningNodeID = null; + return s + }) +} + +function executionError(promptID: PromptID, message: string) { + console.debug("[queueState] executionError", promptID, message) + store.update(s => { + const index = s.queuePending.findIndex(e => e.promptID === promptID) + if (index) { + const entry = s.queuePending[index] + s.queuePending = s.queuePending.splice(index, 1); + const completed: CompletedQueueEntry = { + entry, + type: "error", + error: message + } + s.queueCompleted.push(completed) + } + s.progress = null; + s.runningNodeID = null; + return s + }) +} + +function afterQueued(promptID: PromptID, number: number, prompt: SerializedPrompt, extraData: any) { + console.debug("[queueState] afterQueued", promptID, Object.keys(prompt.workflow.nodes)) + store.update(s => { + const entry: QueueEntry = { + number, + promptID, + prompt, + extraData, + goodOutputs: [], + outputs: {} + } + s.queuePending.push(entry) + return s + }) +} + +function onExecuted(promptID: PromptID, nodeID: NodeID, output: GalleryOutput) { + console.debug("[queueState] onExecuted", promptID, nodeID, output) + store.update(s => { + const entry = s.queuePending.find(e => e.promptID === promptID) + if (entry) { + entry.outputs[nodeID] = output; + s.queuePending.push(entry) + } + return s + }) +} + const queueStateStore: WritableQueueStateStore = { ...store, + queueUpdated, statusUpdated, + progressUpdated, executingUpdated, - progressUpdated + executionCached, + executionError, + afterQueued, + onExecuted } export default queueStateStore; diff --git a/src/lib/stores/uiState.ts b/src/lib/stores/uiState.ts index 332fb10..c827741 100644 --- a/src/lib/stores/uiState.ts +++ b/src/lib/stores/uiState.ts @@ -10,11 +10,17 @@ export type UIState = { uiUnlocked: boolean, uiEditMode: UIEditMode, + reconnecting: boolean, isSavingToLocalStorage: boolean } -export type WritableUIStateStore = Writable; -const store: WritableUIStateStore = writable( +type UIStateOps = { + reconnecting: () => void, + reconnected: () => void, +} + +export type WritableUIStateStore = Writable & UIStateOps; +const store: Writable = writable( { graphLocked: false, nodesLocked: false, @@ -22,11 +28,22 @@ const store: WritableUIStateStore = writable( uiUnlocked: false, uiEditMode: "widgets", + reconnecting: false, isSavingToLocalStorage: false }) +function reconnecting() { + store.update(s => { s.reconnecting = true; return s; }) +} + +function reconnected() { + store.update(s => { s.reconnecting = false; return s; }) +} + const uiStateStore: WritableUIStateStore = { - ...store + ...store, + reconnecting, + reconnected } export default uiStateStore; diff --git a/src/mobile/GenToolbar.svelte b/src/mobile/GenToolbar.svelte index a2be969..fb2c9ac 100644 --- a/src/mobile/GenToolbar.svelte +++ b/src/mobile/GenToolbar.svelte @@ -57,9 +57,9 @@
- {#if $queueState.runningNodeId || $queueState.progress} + {#if $queueState.runningNodeID || $queueState.progress}
- Node: {getNodeInfo($queueState.runningNodeId)} + Node: {getNodeInfo($queueState.runningNodeID)}