From 2536fbbe27b4cbe79e96daa09824337fce7c9311 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 7 Apr 2023 08:15:29 -0500 Subject: [PATCH] Port some ComfyUI features --- src/lib/ComfyGraphCanvas.ts | 174 +++++++++++++ src/lib/ComfyGraphNode.ts | 5 + src/lib/api.ts | 462 ++++++++++++++++----------------- src/lib/components/ComfyApp.ts | 110 ++++++-- 4 files changed, 503 insertions(+), 248 deletions(-) create mode 100644 src/lib/ComfyGraphCanvas.ts create mode 100644 src/lib/ComfyGraphNode.ts diff --git a/src/lib/ComfyGraphCanvas.ts b/src/lib/ComfyGraphCanvas.ts new file mode 100644 index 0000000..bb92061 --- /dev/null +++ b/src/lib/ComfyGraphCanvas.ts @@ -0,0 +1,174 @@ +import { BuiltInSlotShape, LGraph, LGraphCanvas, LGraphNode, LiteGraph, NodeMode, type MouseEventExt, type Vector2, type Vector4 } from "@litegraph-ts/core"; +import type ComfyApp from "./components/ComfyApp"; + +export default class ComfyGraphCanvas extends LGraphCanvas { + app: ComfyApp + + constructor( + app: ComfyApp, + canvas: HTMLCanvasElement | string, + graph?: LGraph, + options: { + skip_render?: boolean; + skip_events?: boolean; + autoresize?: boolean; + viewport?: Vector4; + } = {} + ) { + super(canvas, graph, options); + this.app = app; + } + + override drawNodeShape( + node: LGraphNode, + ctx: CanvasRenderingContext2D, + size: Vector2, + fgColor: string, + bgColor: string, + selected: boolean, + mouseOver: boolean + ): void { + super.drawNodeShape(node, ctx, size, fgColor, bgColor, selected, mouseOver); + + let color = null; + if (node.id === +this.app.runningNodeId) { + color = "#0f0"; + } else if (this.app.dragOverNode && node.id === this.app.dragOverNode.id) { + color = "dodgerblue"; + } + + if (color) { + const shape = node.shape || BuiltInSlotShape.ROUND_SHAPE; + ctx.lineWidth = 1; + ctx.globalAlpha = 0.8; + ctx.beginPath(); + if (shape == BuiltInSlotShape.BOX_SHAPE) + ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT); + else if (shape == BuiltInSlotShape.ROUND_SHAPE || (shape == BuiltInSlotShape.CARD_SHAPE && node.flags.collapsed)) + ctx.roundRect( + -6, + -6 - LiteGraph.NODE_TITLE_HEIGHT, + 12 + size[0] + 1, + 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, + this.round_radius * 2 + ); + else if (shape == BuiltInSlotShape.CARD_SHAPE) + ctx.roundRect( + -6, + -6 + LiteGraph.NODE_TITLE_HEIGHT, + 12 + size[0] + 1, + 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, + this.round_radius * 2, + 2 + ); + else if (shape == BuiltInSlotShape.CIRCLE_SHAPE) + ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); + ctx.strokeStyle = color; + ctx.stroke(); + ctx.strokeStyle = fgColor; + ctx.globalAlpha = 1; + + if (this.app.progress) { + ctx.fillStyle = "green"; + ctx.fillRect(0, 0, size[0] * (this.app.progress.value / this.app.progress.max), 6); + ctx.fillStyle = bgColor; + } + } + } + + override drawNode(node: LGraphNode, ctx: CanvasRenderingContext2D): void { + var editor_alpha = this.editor_alpha; + if (node.mode === NodeMode.NEVER) { // never + this.editor_alpha = 0.4; + } + const res = super.drawNode(node, ctx); + this.editor_alpha = editor_alpha; + + return res; + } + + override drawGroups(canvas: HTMLCanvasElement, ctx: CanvasRenderingContext2D) { + if (!this.graph) { + return; + } + + var groups = this.graph._groups; + + ctx.save(); + ctx.globalAlpha = 0.7 * this.editor_alpha; + + for (var i = 0; i < groups.length; ++i) { + var group = groups[i]; + + if (!LiteGraph.overlapBounding(this.visible_area, group.bounding)) { + continue; + } //out of the visible area + + ctx.fillStyle = group.color || "#335"; + ctx.strokeStyle = group.color || "#335"; + var pos = group.pos; + var size = group.size; + ctx.globalAlpha = 0.25 * this.editor_alpha; + ctx.beginPath(); + var font_size = + group.fontSize || LiteGraph.DEFAULT_GROUP_FONT_SIZE; + ctx.rect(pos[0] + 0.5, pos[1] + 0.5, size[0], font_size * 1.4); + ctx.fill(); + ctx.globalAlpha = this.editor_alpha; + } + + ctx.restore(); + + const res = super.drawGroups(canvas, ctx); + return res; + } + + /** + * Handle keypress + * + * Ctrl + M mute/unmute selected nodes + */ + override processKey(e: KeyboardEvent): boolean | undefined { + const res = super.processKey(e); + + if (res === false) { + return res; + } + + if (!this.graph) { + return; + } + + var block_default = false; + + if ("localName" in e.target && e.target.localName == "input") { + return; + } + + if (e.type == "keydown") { + // Ctrl + M mute/unmute + if (e.keyCode == 77 && e.ctrlKey) { + if (this.selected_nodes) { + for (var i in this.selected_nodes) { + if (this.selected_nodes[i].mode === 2) { // never + this.selected_nodes[i].mode = 0; // always + } else { + this.selected_nodes[i].mode = 2; // never + } + } + } + block_default = true; + } + } + + this.graph.change(); + + if (block_default) { + e.preventDefault(); + e.stopImmediatePropagation(); + return false; + } + + return res; + } +} diff --git a/src/lib/ComfyGraphNode.ts b/src/lib/ComfyGraphNode.ts new file mode 100644 index 0000000..cb6939e --- /dev/null +++ b/src/lib/ComfyGraphNode.ts @@ -0,0 +1,5 @@ +import { LGraphNode } from "@litegraph-ts/core"; + +export default class ComfyGraphNode extends LGraphNode { + onExecuted?(output: any): void; +} diff --git a/src/lib/api.ts b/src/lib/api.ts index 31bc5bb..9da9ca8 100644 --- a/src/lib/api.ts +++ b/src/lib/api.ts @@ -21,24 +21,24 @@ export default class ComfyAPI extends EventTarget { } override addEventListener(type: string, callback: EventListenerOrEventListenerObject | null, options?: AddEventListenerOptions | boolean) { - super.addEventListener(type, callback, options); - this.registered.add(type); - } + super.addEventListener(type, callback, options); + this.registered.add(type); + } - /** - * Poll status for colab and other things that don't support websockets. - */ - private pollQueue() { - setInterval(async () => { - try { - const resp = await fetch(this.getBackendUrl() + "/prompt"); - const status = await resp.json(); - this.dispatchEvent(new CustomEvent("status", { detail: status })); - } catch (error) { - this.dispatchEvent(new CustomEvent("status", { detail: null })); - } - }, 1000); - } + /** + * Poll status for colab and other things that don't support websockets. + */ + private pollQueue() { + setInterval(async () => { + try { + const resp = await fetch(this.getBackendUrl() + "/prompt"); + const status = await resp.json(); + this.dispatchEvent(new CustomEvent("status", { detail: status })); + } catch (error) { + this.dispatchEvent(new CustomEvent("status", { detail: null })); + } + }, 1000); + } private getBackendUrl(): string { const hostname = this.hostname || location.hostname; @@ -48,243 +48,243 @@ export default class ComfyAPI extends EventTarget { return `${window.location.protocol}//${hostname}:${port}` } - /** - * Creates and connects a WebSocket for realtime updates - * @param {boolean} isReconnect If the socket is connection is a reconnect attempt - */ - private createSocket(isReconnect: boolean = false) { - if (this.socket) { - return; - } + /** + * Creates and connects a WebSocket for realtime updates + * @param {boolean} isReconnect If the socket is connection is a reconnect attempt + */ + private createSocket(isReconnect: boolean = false) { + if (this.socket) { + return; + } - let opened = false; - let existingSession = sessionStorage["Comfy.SessionId"] || ""; - if (existingSession) { - existingSession = "/" + existingSession; - } + let opened = false; + let existingSession = sessionStorage["Comfy.SessionId"] || ""; + if (existingSession) { + existingSession = "/" + existingSession; + } - const hostname = this.hostname || location.host; + const hostname = this.hostname || location.hostname; const port = this.port || location.port; - this.socket = new WebSocket( - `ws${window.location.protocol === "https:" ? "s" : ""}://${hostname}:${port}/ws${existingSession}` - ); + this.socket = new WebSocket( + `ws${window.location.protocol === "https:" ? "s" : ""}://${hostname}:${port}/ws?clientId=${existingSession}` + ); - this.socket.addEventListener("open", () => { - opened = true; - if (isReconnect) { - this.dispatchEvent(new CustomEvent("reconnected")); - } - }); + this.socket.addEventListener("open", () => { + opened = true; + if (isReconnect) { + this.dispatchEvent(new CustomEvent("reconnected")); + } + }); - this.socket.addEventListener("error", () => { - if (this.socket) this.socket.close(); - if (!isReconnect && !opened) { - this.pollQueue(); - } - }); + this.socket.addEventListener("error", () => { + if (this.socket) this.socket.close(); + if (!isReconnect && !opened) { + this.pollQueue(); + } + }); - this.socket.addEventListener("close", () => { - setTimeout(() => { - this.socket = null; - this.createSocket(true); - }, 300); - if (opened) { - this.dispatchEvent(new CustomEvent("status", { detail: null })); - this.dispatchEvent(new CustomEvent("reconnecting")); - } - }); + this.socket.addEventListener("close", () => { + setTimeout(() => { + this.socket = null; + this.createSocket(true); + }, 300); + if (opened) { + this.dispatchEvent(new CustomEvent("status", { detail: null })); + this.dispatchEvent(new CustomEvent("reconnecting")); + } + }); - this.socket.addEventListener("message", (event) => { - try { - const msg = JSON.parse(event.data); - switch (msg.type) { - case "status": - if (msg.data.sid) { - this.clientId = msg.data.sid; - sessionStorage["Comfy.SessionId"] = this.clientId; - } - this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); - break; - case "progress": - this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); - break; - case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); - break; - case "executed": - this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); - break; - default: - if (this.registered.has(msg.type)) { - this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); - } else { - throw new Error("Unknown message type"); - } - } - } catch (error) { - console.warn("Unhandled message:", event.data); - } - }); - } + this.socket.addEventListener("message", (event) => { + try { + const msg = JSON.parse(event.data); + switch (msg.type) { + case "status": + if (msg.data.sid) { + this.clientId = msg.data.sid; + sessionStorage["Comfy.SessionId"] = this.clientId; + } + this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); + break; + case "progress": + this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); + break; + case "executing": + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + break; + case "executed": + this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); + break; + default: + if (this.registered.has(msg.type)) { + this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); + } else { + throw new Error("Unknown message type"); + } + } + } catch (error) { + console.warn("Unhandled message:", event.data); + } + }); + } - /** - * Initialises sockets and realtime updates - */ - init() { - this.createSocket(); - } + /** + * Initialises sockets and realtime updates + */ + init() { + this.createSocket(); + } - /** - * Gets a list of extension urls - * @returns An array of script urls to import - */ - async getExtensions() { - const resp = await fetch(this.getBackendUrl() + `/extensions`, { cache: "no-store" }); - return await resp.json(); - } + /** + * Gets a list of extension urls + * @returns An array of script urls to import + */ + async getExtensions() { + const resp = await fetch(this.getBackendUrl() + `/extensions`, { cache: "no-store" }); + return await resp.json(); + } - /** - * Gets a list of embedding names - * @returns An array of script urls to import - */ - async getEmbeddings() { - const resp = await fetch(this.getBackendUrl() + "/embeddings", { cache: "no-store" }); - return await resp.json(); - } + /** + * Gets a list of embedding names + * @returns An array of script urls to import + */ + async getEmbeddings() { + const resp = await fetch(this.getBackendUrl() + "/embeddings", { cache: "no-store" }); + return await resp.json(); + } - /** - * Loads node object definitions for the graph - * @returns The node definitions - */ - async getNodeDefs() { - const resp = await fetch(this.getBackendUrl() + "/object_info", { cache: "no-store" }); - return await resp.json(); - } + /** + * Loads node object definitions for the graph + * @returns The node definitions + */ + async getNodeDefs() { + const resp = await fetch(this.getBackendUrl() + "/object_info", { cache: "no-store" }); + return await resp.json(); + } - /** - * - * @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue - * @param {object} prompt The prompt data to queue - */ - async queuePrompt(number: number, { output, workflow }) { - const body: PromptRequestBody = { - client_id: this.clientId, - prompt: output, - extra_data: { extra_pnginfo: { workflow } }, + /** + * + * @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue + * @param {object} prompt The prompt data to queue + */ + async queuePrompt(number: number, { output, workflow }) { + const body: PromptRequestBody = { + client_id: this.clientId, + prompt: output, + extra_data: { extra_pnginfo: { workflow } }, front: false, number: null - }; + }; - if (number === -1) { - body.front = true; - } else if (number != 0) { - body.number = number; - } + if (number === -1) { + body.front = true; + } else if (number != 0) { + body.number = number; + } - const res = await fetch(this.getBackendUrl() + "/prompt", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(body), - }); + const res = await fetch(this.getBackendUrl() + "/prompt", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + }); - if (res.status !== 200) { - throw { - response: await res.text(), - }; - } - } + if (res.status !== 200) { + throw { + response: await res.text(), + }; + } + } - /** - * Loads a list of items (queue or history) - * @param {string} type The type of items to load, queue or history - * @returns The items of the specified type grouped by their status - */ - async getItems(type: QueueItemType) { - if (type === "queue") { - return this.getQueue(); - } - return this.getHistory(); - } + /** + * Loads a list of items (queue or history) + * @param {string} type The type of items to load, queue or history + * @returns The items of the specified type grouped by their status + */ + async getItems(type: QueueItemType) { + if (type === "queue") { + return this.getQueue(); + } + return this.getHistory(); + } - /** - * Gets the current state of the queue - * @returns The currently running and queued items - */ - async getQueue() { - try { - const res = await fetch(this.getBackendUrl() + "/queue"); - const data = await res.json(); - return { - // Running action uses a different endpoint for cancelling - Running: data.queue_running.map((prompt) => ({ - prompt, - remove: { name: "Cancel", cb: () => this.interrupt() }, - })), - Pending: data.queue_pending.map((prompt) => ({ prompt })), - }; - } catch (error) { - console.error(error); - return { Running: [], Pending: [] }; - } - } + /** + * Gets the current state of the queue + * @returns The currently running and queued items + */ + async getQueue() { + try { + const res = await fetch(this.getBackendUrl() + "/queue"); + const data = await res.json(); + return { + // Running action uses a different endpoint for cancelling + Running: data.queue_running.map((prompt) => ({ + prompt, + remove: { name: "Cancel", cb: () => this.interrupt() }, + })), + Pending: data.queue_pending.map((prompt) => ({ prompt })), + }; + } catch (error) { + console.error(error); + return { Running: [], Pending: [] }; + } + } - /** - * Gets the prompt execution history - * @returns Prompt history including node outputs - */ - async getHistory() { - try { - const res = await fetch(this.getBackendUrl() + "/history"); - return { History: Object.values(await res.json()) }; - } catch (error) { - console.error(error); - return { History: [] }; - } - } + /** + * Gets the prompt execution history + * @returns Prompt history including node outputs + */ + async getHistory() { + try { + const res = await fetch(this.getBackendUrl() + "/history"); + return { History: Object.values(await res.json()) }; + } catch (error) { + console.error(error); + return { History: [] }; + } + } - /** - * Sends a POST request to the API - * @param {*} type The endpoint to post to - * @param {*} body Optional POST data - */ - private async postItem(type: string, body: any) { - try { - await fetch("/" + type, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: body ? JSON.stringify(body) : undefined, - }); - } catch (error) { - console.error(error); - } - } + /** + * Sends a POST request to the API + * @param {*} type The endpoint to post to + * @param {*} body Optional POST data + */ + private async postItem(type: string, body: any) { + try { + await fetch("/" + type, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: body ? JSON.stringify(body) : undefined, + }); + } catch (error) { + console.error(error); + } + } - /** - * Deletes an item from the specified list - * @param {string} type The type of item to delete, queue or history - * @param {number} id The id of the item to delete - */ - async deleteItem(type: string, id: number) { - await this.postItem(type, { delete: [id] }); - } + /** + * Deletes an item from the specified list + * @param {string} type The type of item to delete, queue or history + * @param {number} id The id of the item to delete + */ + async deleteItem(type: string, id: number) { + await this.postItem(type, { delete: [id] }); + } - /** - * Clears the specified list - * @param {string} type The type of list to clear, queue or history - */ - async clearItems(type: string) { - await this.postItem(type, { clear: true }); - } + /** + * Clears the specified list + * @param {string} type The type of list to clear, queue or history + */ + async clearItems(type: string) { + await this.postItem(type, { clear: true }); + } - /** - * Interrupts the execution of the running prompt - */ - async interrupt() { - await this.postItem("interrupt", null); - } + /** + * Interrupts the execution of the running prompt + */ + async interrupt() { + await this.postItem("interrupt", null); + } } diff --git a/src/lib/components/ComfyApp.ts b/src/lib/components/ComfyApp.ts index 38fa604..ea7be82 100644 --- a/src/lib/components/ComfyApp.ts +++ b/src/lib/components/ComfyApp.ts @@ -10,6 +10,8 @@ import type TypedEmitter from "typed-emitter"; // Import nodes import * as basic from "@litegraph-ts/nodes-basic" import * as nodes from "$lib/nodes/index" +import ComfyGraphCanvas from "$lib/ComfyGraphCanvas"; +import type ComfyGraphNode from "$lib/ComfyGraphNode"; LiteGraph.catch_exceptions = false; @@ -34,6 +36,11 @@ interface ComfyGraphNodeExecutable extends LGraphNodeExecutable { applyToGraph(workflow: SerializedLGraph, SerializedLLink, SerializedLGraphGroup>): void; } +export type Progress = { + value: number, + max: number +} + export default class ComfyApp { api: ComfyAPI; canvasEl: HTMLCanvasElement | null = null; @@ -44,6 +51,12 @@ export default class ComfyApp { nodeOutputs: Record = {}; eventBus: TypedEmitter = new EventEmitter() as TypedEmitter; + runningNodeId: number | null = null; + dragOverNode: LGraphNode | null = null; + progress: Progress | null = null; + shiftDown: boolean = false; + selectedGroupMoving: boolean = false; + private queueItems: QueueItem[] = []; private processingQueue: boolean = false; @@ -52,12 +65,9 @@ export default class ComfyApp { } async setup(): Promise { - this.addProcessMouseHandler(); - this.addProcessKeyHandler(); - this.canvasEl = document.getElementById("graph-canvas") as HTMLCanvasElement; this.lGraph = new LGraph(); - this.lCanvas = new LGraphCanvas(this.canvasEl, this.lGraph); + this.lCanvas = new ComfyGraphCanvas(this, this.canvasEl, this.lGraph); this.canvasCtx = this.canvasEl.getContext("2d"); this.addGraphLifecycleHooks(); @@ -91,12 +101,10 @@ export default class ComfyApp { // Save current workflow automatically setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.lGraph.serialize())), 1000); - // this.#addDrawNodeHandler(); - // this.#addDrawGroupsHandler(); - // this.#addApiUpdateHandlers(); + this.addApiUpdateHandlers(); this.addDropHandler(); - // this.#addPasteHandler(); - // this.#addKeyboardHandler(); + this.addPasteHandler(); + this.addKeyboardHandler(); // await this.#invokeExtensionsAsync("setup"); @@ -115,14 +123,6 @@ export default class ComfyApp { this.lCanvas.draw(true, true); } - private addProcessMouseHandler() { - - } - - private addProcessKeyHandler() { - - } - private graphOnConfigure() { console.log("Configured"); this.eventBus.emit("configured", this.lGraph); @@ -292,6 +292,82 @@ export default class ComfyApp { this.dropZone.addEventListener('drop', this.handleDrop.bind(this)); } + /** + * Adds a handler on paste that extracts and loads workflows from pasted JSON data + */ + private addPasteHandler() { + document.addEventListener("paste", (e) => { + let data = (e.clipboardData || (window as any).clipboardData).getData("text/plain"); + let workflow; + try { + data = data.slice(data.indexOf("{")); + workflow = JSON.parse(data); + } catch (err) { + try { + data = data.slice(data.indexOf("workflow\n")); + data = data.slice(data.indexOf("{")); + workflow = JSON.parse(data); + } catch (error) { } + } + + if (workflow && workflow.version && workflow.nodes && workflow.extra) { + this.loadGraphData(workflow); + } + }); + } + + /** + * Handles updates from the API socket + */ + private addApiUpdateHandlers() { + this.api.addEventListener("status", ({ detail }: CustomEvent) => { + // this.ui.setStatus(detail); + }); + + this.api.addEventListener("reconnecting", () => { + // this.ui.dialog.show("Reconnecting..."); + }); + + this.api.addEventListener("reconnected", () => { + // this.ui.dialog.close(); + }); + + this.api.addEventListener("progress", ({ detail }: CustomEvent) => { + this.progress = detail; + this.lGraph.setDirtyCanvas(true, false); + }); + + this.api.addEventListener("executing", ({ detail }: CustomEvent) => { + this.progress = null; + this.runningNodeId = detail; + this.lGraph.setDirtyCanvas(true, false); + }); + + this.api.addEventListener("executed", ({ detail }: CustomEvent) => { + this.nodeOutputs[detail.node] = detail.output; + const node = this.lGraph.getNodeById(detail.node) as ComfyGraphNode; + if (node?.onExecuted) { + node.onExecuted(detail.output); + } + }); + + this.api.init(); + } + + private addKeyboardHandler() { + window.addEventListener("keydown", (e) => { + this.shiftDown = e.shiftKey; + + // Queue prompt using ctrl or command + enter + if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) { + this.queuePrompt(e.shiftKey ? -1 : 0); + } + }); + window.addEventListener("keyup", (e) => { + this.shiftDown = e.shiftKey; + }); + } + /** * Populates the graph with the specified workflow data * @param {*} graphData A serialized graph object