Port some ComfyUI features

This commit is contained in:
space-nuko
2023-04-07 08:15:29 -05:00
parent abcdfb7345
commit 2536fbbe27
4 changed files with 503 additions and 248 deletions

174
src/lib/ComfyGraphCanvas.ts Normal file
View File

@@ -0,0 +1,174 @@
import { BuiltInSlotShape, LGraph, LGraphCanvas, LGraphNode, LiteGraph, NodeMode, type MouseEventExt, type Vector2, type Vector4 } from "@litegraph-ts/core";
import type ComfyApp from "./components/ComfyApp";
export default class ComfyGraphCanvas extends LGraphCanvas {
app: ComfyApp
constructor(
app: ComfyApp,
canvas: HTMLCanvasElement | string,
graph?: LGraph,
options: {
skip_render?: boolean;
skip_events?: boolean;
autoresize?: boolean;
viewport?: Vector4;
} = {}
) {
super(canvas, graph, options);
this.app = app;
}
override drawNodeShape(
node: LGraphNode,
ctx: CanvasRenderingContext2D,
size: Vector2,
fgColor: string,
bgColor: string,
selected: boolean,
mouseOver: boolean
): void {
super.drawNodeShape(node, ctx, size, fgColor, bgColor, selected, mouseOver);
let color = null;
if (node.id === +this.app.runningNodeId) {
color = "#0f0";
} else if (this.app.dragOverNode && node.id === this.app.dragOverNode.id) {
color = "dodgerblue";
}
if (color) {
const shape = node.shape || BuiltInSlotShape.ROUND_SHAPE;
ctx.lineWidth = 1;
ctx.globalAlpha = 0.8;
ctx.beginPath();
if (shape == BuiltInSlotShape.BOX_SHAPE)
ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
else if (shape == BuiltInSlotShape.ROUND_SHAPE || (shape == BuiltInSlotShape.CARD_SHAPE && node.flags.collapsed))
ctx.roundRect(
-6,
-6 - LiteGraph.NODE_TITLE_HEIGHT,
12 + size[0] + 1,
12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT,
this.round_radius * 2
);
else if (shape == BuiltInSlotShape.CARD_SHAPE)
ctx.roundRect(
-6,
-6 + LiteGraph.NODE_TITLE_HEIGHT,
12 + size[0] + 1,
12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT,
this.round_radius * 2,
2
);
else if (shape == BuiltInSlotShape.CIRCLE_SHAPE)
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2);
ctx.strokeStyle = color;
ctx.stroke();
ctx.strokeStyle = fgColor;
ctx.globalAlpha = 1;
if (this.app.progress) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (this.app.progress.value / this.app.progress.max), 6);
ctx.fillStyle = bgColor;
}
}
}
override drawNode(node: LGraphNode, ctx: CanvasRenderingContext2D): void {
var editor_alpha = this.editor_alpha;
if (node.mode === NodeMode.NEVER) { // never
this.editor_alpha = 0.4;
}
const res = super.drawNode(node, ctx);
this.editor_alpha = editor_alpha;
return res;
}
override drawGroups(canvas: HTMLCanvasElement, ctx: CanvasRenderingContext2D) {
if (!this.graph) {
return;
}
var groups = this.graph._groups;
ctx.save();
ctx.globalAlpha = 0.7 * this.editor_alpha;
for (var i = 0; i < groups.length; ++i) {
var group = groups[i];
if (!LiteGraph.overlapBounding(this.visible_area, group.bounding)) {
continue;
} //out of the visible area
ctx.fillStyle = group.color || "#335";
ctx.strokeStyle = group.color || "#335";
var pos = group.pos;
var size = group.size;
ctx.globalAlpha = 0.25 * this.editor_alpha;
ctx.beginPath();
var font_size =
group.fontSize || LiteGraph.DEFAULT_GROUP_FONT_SIZE;
ctx.rect(pos[0] + 0.5, pos[1] + 0.5, size[0], font_size * 1.4);
ctx.fill();
ctx.globalAlpha = this.editor_alpha;
}
ctx.restore();
const res = super.drawGroups(canvas, ctx);
return res;
}
/**
* Handle keypress
*
* Ctrl + M mute/unmute selected nodes
*/
override processKey(e: KeyboardEvent): boolean | undefined {
const res = super.processKey(e);
if (res === false) {
return res;
}
if (!this.graph) {
return;
}
var block_default = false;
if ("localName" in e.target && e.target.localName == "input") {
return;
}
if (e.type == "keydown") {
// Ctrl + M mute/unmute
if (e.keyCode == 77 && e.ctrlKey) {
if (this.selected_nodes) {
for (var i in this.selected_nodes) {
if (this.selected_nodes[i].mode === 2) { // never
this.selected_nodes[i].mode = 0; // always
} else {
this.selected_nodes[i].mode = 2; // never
}
}
}
block_default = true;
}
}
this.graph.change();
if (block_default) {
e.preventDefault();
e.stopImmediatePropagation();
return false;
}
return res;
}
}

View File

@@ -0,0 +1,5 @@
import { LGraphNode } from "@litegraph-ts/core";
export default class ComfyGraphNode extends LGraphNode {
onExecuted?(output: any): void;
}

View File

@@ -21,24 +21,24 @@ export default class ComfyAPI extends EventTarget {
} }
override addEventListener(type: string, callback: EventListenerOrEventListenerObject | null, options?: AddEventListenerOptions | boolean) { override addEventListener(type: string, callback: EventListenerOrEventListenerObject | null, options?: AddEventListenerOptions | boolean) {
super.addEventListener(type, callback, options); super.addEventListener(type, callback, options);
this.registered.add(type); this.registered.add(type);
} }
/** /**
* Poll status for colab and other things that don't support websockets. * Poll status for colab and other things that don't support websockets.
*/ */
private pollQueue() { private pollQueue() {
setInterval(async () => { setInterval(async () => {
try { try {
const resp = await fetch(this.getBackendUrl() + "/prompt"); const resp = await fetch(this.getBackendUrl() + "/prompt");
const status = await resp.json(); const status = await resp.json();
this.dispatchEvent(new CustomEvent("status", { detail: status })); this.dispatchEvent(new CustomEvent("status", { detail: status }));
} catch (error) { } catch (error) {
this.dispatchEvent(new CustomEvent("status", { detail: null })); this.dispatchEvent(new CustomEvent("status", { detail: null }));
} }
}, 1000); }, 1000);
} }
private getBackendUrl(): string { private getBackendUrl(): string {
const hostname = this.hostname || location.hostname; const hostname = this.hostname || location.hostname;
@@ -48,243 +48,243 @@ export default class ComfyAPI extends EventTarget {
return `${window.location.protocol}//${hostname}:${port}` return `${window.location.protocol}//${hostname}:${port}`
} }
/** /**
* Creates and connects a WebSocket for realtime updates * Creates and connects a WebSocket for realtime updates
* @param {boolean} isReconnect If the socket is connection is a reconnect attempt * @param {boolean} isReconnect If the socket is connection is a reconnect attempt
*/ */
private createSocket(isReconnect: boolean = false) { private createSocket(isReconnect: boolean = false) {
if (this.socket) { if (this.socket) {
return; return;
} }
let opened = false; let opened = false;
let existingSession = sessionStorage["Comfy.SessionId"] || ""; let existingSession = sessionStorage["Comfy.SessionId"] || "";
if (existingSession) { if (existingSession) {
existingSession = "/" + existingSession; existingSession = "/" + existingSession;
} }
const hostname = this.hostname || location.host; const hostname = this.hostname || location.hostname;
const port = this.port || location.port; const port = this.port || location.port;
this.socket = new WebSocket( this.socket = new WebSocket(
`ws${window.location.protocol === "https:" ? "s" : ""}://${hostname}:${port}/ws${existingSession}` `ws${window.location.protocol === "https:" ? "s" : ""}://${hostname}:${port}/ws?clientId=${existingSession}`
); );
this.socket.addEventListener("open", () => { this.socket.addEventListener("open", () => {
opened = true; opened = true;
if (isReconnect) { if (isReconnect) {
this.dispatchEvent(new CustomEvent("reconnected")); this.dispatchEvent(new CustomEvent("reconnected"));
} }
}); });
this.socket.addEventListener("error", () => { this.socket.addEventListener("error", () => {
if (this.socket) this.socket.close(); if (this.socket) this.socket.close();
if (!isReconnect && !opened) { if (!isReconnect && !opened) {
this.pollQueue(); this.pollQueue();
} }
}); });
this.socket.addEventListener("close", () => { this.socket.addEventListener("close", () => {
setTimeout(() => { setTimeout(() => {
this.socket = null; this.socket = null;
this.createSocket(true); this.createSocket(true);
}, 300); }, 300);
if (opened) { if (opened) {
this.dispatchEvent(new CustomEvent("status", { detail: null })); this.dispatchEvent(new CustomEvent("status", { detail: null }));
this.dispatchEvent(new CustomEvent("reconnecting")); this.dispatchEvent(new CustomEvent("reconnecting"));
} }
}); });
this.socket.addEventListener("message", (event) => { this.socket.addEventListener("message", (event) => {
try { try {
const msg = JSON.parse(event.data); const msg = JSON.parse(event.data);
switch (msg.type) { switch (msg.type) {
case "status": case "status":
if (msg.data.sid) { if (msg.data.sid) {
this.clientId = msg.data.sid; this.clientId = msg.data.sid;
sessionStorage["Comfy.SessionId"] = this.clientId; sessionStorage["Comfy.SessionId"] = this.clientId;
} }
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break; break;
case "progress": case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break; break;
case "executing": case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
break; break;
case "executed": case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
break; break;
default: default:
if (this.registered.has(msg.type)) { if (this.registered.has(msg.type)) {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
} else { } else {
throw new Error("Unknown message type"); throw new Error("Unknown message type");
} }
} }
} catch (error) { } catch (error) {
console.warn("Unhandled message:", event.data); console.warn("Unhandled message:", event.data);
} }
}); });
} }
/** /**
* Initialises sockets and realtime updates * Initialises sockets and realtime updates
*/ */
init() { init() {
this.createSocket(); this.createSocket();
} }
/** /**
* Gets a list of extension urls * Gets a list of extension urls
* @returns An array of script urls to import * @returns An array of script urls to import
*/ */
async getExtensions() { async getExtensions() {
const resp = await fetch(this.getBackendUrl() + `/extensions`, { cache: "no-store" }); const resp = await fetch(this.getBackendUrl() + `/extensions`, { cache: "no-store" });
return await resp.json(); return await resp.json();
} }
/** /**
* Gets a list of embedding names * Gets a list of embedding names
* @returns An array of script urls to import * @returns An array of script urls to import
*/ */
async getEmbeddings() { async getEmbeddings() {
const resp = await fetch(this.getBackendUrl() + "/embeddings", { cache: "no-store" }); const resp = await fetch(this.getBackendUrl() + "/embeddings", { cache: "no-store" });
return await resp.json(); return await resp.json();
} }
/** /**
* Loads node object definitions for the graph * Loads node object definitions for the graph
* @returns The node definitions * @returns The node definitions
*/ */
async getNodeDefs() { async getNodeDefs() {
const resp = await fetch(this.getBackendUrl() + "/object_info", { cache: "no-store" }); const resp = await fetch(this.getBackendUrl() + "/object_info", { cache: "no-store" });
return await resp.json(); return await resp.json();
} }
/** /**
* *
* @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue * @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue
* @param {object} prompt The prompt data to queue * @param {object} prompt The prompt data to queue
*/ */
async queuePrompt(number: number, { output, workflow }) { async queuePrompt(number: number, { output, workflow }) {
const body: PromptRequestBody = { const body: PromptRequestBody = {
client_id: this.clientId, client_id: this.clientId,
prompt: output, prompt: output,
extra_data: { extra_pnginfo: { workflow } }, extra_data: { extra_pnginfo: { workflow } },
front: false, front: false,
number: null number: null
}; };
if (number === -1) { if (number === -1) {
body.front = true; body.front = true;
} else if (number != 0) { } else if (number != 0) {
body.number = number; body.number = number;
} }
const res = await fetch(this.getBackendUrl() + "/prompt", { const res = await fetch(this.getBackendUrl() + "/prompt", {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
body: JSON.stringify(body), body: JSON.stringify(body),
}); });
if (res.status !== 200) { if (res.status !== 200) {
throw { throw {
response: await res.text(), response: await res.text(),
}; };
} }
} }
/** /**
* Loads a list of items (queue or history) * Loads a list of items (queue or history)
* @param {string} type The type of items to load, queue or history * @param {string} type The type of items to load, queue or history
* @returns The items of the specified type grouped by their status * @returns The items of the specified type grouped by their status
*/ */
async getItems(type: QueueItemType) { async getItems(type: QueueItemType) {
if (type === "queue") { if (type === "queue") {
return this.getQueue(); return this.getQueue();
} }
return this.getHistory(); return this.getHistory();
} }
/** /**
* Gets the current state of the queue * Gets the current state of the queue
* @returns The currently running and queued items * @returns The currently running and queued items
*/ */
async getQueue() { async getQueue() {
try { try {
const res = await fetch(this.getBackendUrl() + "/queue"); const res = await fetch(this.getBackendUrl() + "/queue");
const data = await res.json(); const data = await res.json();
return { return {
// Running action uses a different endpoint for cancelling // Running action uses a different endpoint for cancelling
Running: data.queue_running.map((prompt) => ({ Running: data.queue_running.map((prompt) => ({
prompt, prompt,
remove: { name: "Cancel", cb: () => this.interrupt() }, remove: { name: "Cancel", cb: () => this.interrupt() },
})), })),
Pending: data.queue_pending.map((prompt) => ({ prompt })), Pending: data.queue_pending.map((prompt) => ({ prompt })),
}; };
} catch (error) { } catch (error) {
console.error(error); console.error(error);
return { Running: [], Pending: [] }; return { Running: [], Pending: [] };
} }
} }
/** /**
* Gets the prompt execution history * Gets the prompt execution history
* @returns Prompt history including node outputs * @returns Prompt history including node outputs
*/ */
async getHistory() { async getHistory() {
try { try {
const res = await fetch(this.getBackendUrl() + "/history"); const res = await fetch(this.getBackendUrl() + "/history");
return { History: Object.values(await res.json()) }; return { History: Object.values(await res.json()) };
} catch (error) { } catch (error) {
console.error(error); console.error(error);
return { History: [] }; return { History: [] };
} }
} }
/** /**
* Sends a POST request to the API * Sends a POST request to the API
* @param {*} type The endpoint to post to * @param {*} type The endpoint to post to
* @param {*} body Optional POST data * @param {*} body Optional POST data
*/ */
private async postItem(type: string, body: any) { private async postItem(type: string, body: any) {
try { try {
await fetch("/" + type, { await fetch("/" + type, {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
body: body ? JSON.stringify(body) : undefined, body: body ? JSON.stringify(body) : undefined,
}); });
} catch (error) { } catch (error) {
console.error(error); console.error(error);
} }
} }
/** /**
* Deletes an item from the specified list * Deletes an item from the specified list
* @param {string} type The type of item to delete, queue or history * @param {string} type The type of item to delete, queue or history
* @param {number} id The id of the item to delete * @param {number} id The id of the item to delete
*/ */
async deleteItem(type: string, id: number) { async deleteItem(type: string, id: number) {
await this.postItem(type, { delete: [id] }); await this.postItem(type, { delete: [id] });
} }
/** /**
* Clears the specified list * Clears the specified list
* @param {string} type The type of list to clear, queue or history * @param {string} type The type of list to clear, queue or history
*/ */
async clearItems(type: string) { async clearItems(type: string) {
await this.postItem(type, { clear: true }); await this.postItem(type, { clear: true });
} }
/** /**
* Interrupts the execution of the running prompt * Interrupts the execution of the running prompt
*/ */
async interrupt() { async interrupt() {
await this.postItem("interrupt", null); await this.postItem("interrupt", null);
} }
} }

View File

@@ -10,6 +10,8 @@ import type TypedEmitter from "typed-emitter";
// Import nodes // Import nodes
import * as basic from "@litegraph-ts/nodes-basic" import * as basic from "@litegraph-ts/nodes-basic"
import * as nodes from "$lib/nodes/index" import * as nodes from "$lib/nodes/index"
import ComfyGraphCanvas from "$lib/ComfyGraphCanvas";
import type ComfyGraphNode from "$lib/ComfyGraphNode";
LiteGraph.catch_exceptions = false; LiteGraph.catch_exceptions = false;
@@ -34,6 +36,11 @@ interface ComfyGraphNodeExecutable extends LGraphNodeExecutable {
applyToGraph(workflow: SerializedLGraph<SerializedLGraphNode<LGraphNode>, SerializedLLink, SerializedLGraphGroup>): void; applyToGraph(workflow: SerializedLGraph<SerializedLGraphNode<LGraphNode>, SerializedLLink, SerializedLGraphGroup>): void;
} }
export type Progress = {
value: number,
max: number
}
export default class ComfyApp { export default class ComfyApp {
api: ComfyAPI; api: ComfyAPI;
canvasEl: HTMLCanvasElement | null = null; canvasEl: HTMLCanvasElement | null = null;
@@ -44,6 +51,12 @@ export default class ComfyApp {
nodeOutputs: Record<string, any> = {}; nodeOutputs: Record<string, any> = {};
eventBus: TypedEmitter<ComfyAppEvents> = new EventEmitter() as TypedEmitter<ComfyAppEvents>; eventBus: TypedEmitter<ComfyAppEvents> = new EventEmitter() as TypedEmitter<ComfyAppEvents>;
runningNodeId: number | null = null;
dragOverNode: LGraphNode | null = null;
progress: Progress | null = null;
shiftDown: boolean = false;
selectedGroupMoving: boolean = false;
private queueItems: QueueItem[] = []; private queueItems: QueueItem[] = [];
private processingQueue: boolean = false; private processingQueue: boolean = false;
@@ -52,12 +65,9 @@ export default class ComfyApp {
} }
async setup(): Promise<void> { async setup(): Promise<void> {
this.addProcessMouseHandler();
this.addProcessKeyHandler();
this.canvasEl = document.getElementById("graph-canvas") as HTMLCanvasElement; this.canvasEl = document.getElementById("graph-canvas") as HTMLCanvasElement;
this.lGraph = new LGraph(); this.lGraph = new LGraph();
this.lCanvas = new LGraphCanvas(this.canvasEl, this.lGraph); this.lCanvas = new ComfyGraphCanvas(this, this.canvasEl, this.lGraph);
this.canvasCtx = this.canvasEl.getContext("2d"); this.canvasCtx = this.canvasEl.getContext("2d");
this.addGraphLifecycleHooks(); this.addGraphLifecycleHooks();
@@ -91,12 +101,10 @@ export default class ComfyApp {
// Save current workflow automatically // Save current workflow automatically
setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.lGraph.serialize())), 1000); setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.lGraph.serialize())), 1000);
// this.#addDrawNodeHandler(); this.addApiUpdateHandlers();
// this.#addDrawGroupsHandler();
// this.#addApiUpdateHandlers();
this.addDropHandler(); this.addDropHandler();
// this.#addPasteHandler(); this.addPasteHandler();
// this.#addKeyboardHandler(); this.addKeyboardHandler();
// await this.#invokeExtensionsAsync("setup"); // await this.#invokeExtensionsAsync("setup");
@@ -115,14 +123,6 @@ export default class ComfyApp {
this.lCanvas.draw(true, true); this.lCanvas.draw(true, true);
} }
private addProcessMouseHandler() {
}
private addProcessKeyHandler() {
}
private graphOnConfigure() { private graphOnConfigure() {
console.log("Configured"); console.log("Configured");
this.eventBus.emit("configured", this.lGraph); this.eventBus.emit("configured", this.lGraph);
@@ -292,6 +292,82 @@ export default class ComfyApp {
this.dropZone.addEventListener('drop', this.handleDrop.bind(this)); this.dropZone.addEventListener('drop', this.handleDrop.bind(this));
} }
/**
* Adds a handler on paste that extracts and loads workflows from pasted JSON data
*/
private addPasteHandler() {
document.addEventListener("paste", (e) => {
let data = (e.clipboardData || (window as any).clipboardData).getData("text/plain");
let workflow;
try {
data = data.slice(data.indexOf("{"));
workflow = JSON.parse(data);
} catch (err) {
try {
data = data.slice(data.indexOf("workflow\n"));
data = data.slice(data.indexOf("{"));
workflow = JSON.parse(data);
} catch (error) { }
}
if (workflow && workflow.version && workflow.nodes && workflow.extra) {
this.loadGraphData(workflow);
}
});
}
/**
* Handles updates from the API socket
*/
private addApiUpdateHandlers() {
this.api.addEventListener("status", ({ detail }: CustomEvent) => {
// this.ui.setStatus(detail);
});
this.api.addEventListener("reconnecting", () => {
// this.ui.dialog.show("Reconnecting...");
});
this.api.addEventListener("reconnected", () => {
// this.ui.dialog.close();
});
this.api.addEventListener("progress", ({ detail }: CustomEvent) => {
this.progress = detail;
this.lGraph.setDirtyCanvas(true, false);
});
this.api.addEventListener("executing", ({ detail }: CustomEvent) => {
this.progress = null;
this.runningNodeId = detail;
this.lGraph.setDirtyCanvas(true, false);
});
this.api.addEventListener("executed", ({ detail }: CustomEvent) => {
this.nodeOutputs[detail.node] = detail.output;
const node = this.lGraph.getNodeById(detail.node) as ComfyGraphNode;
if (node?.onExecuted) {
node.onExecuted(detail.output);
}
});
this.api.init();
}
private addKeyboardHandler() {
window.addEventListener("keydown", (e) => {
this.shiftDown = e.shiftKey;
// Queue prompt using ctrl or command + enter
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
this.queuePrompt(e.shiftKey ? -1 : 0);
}
});
window.addEventListener("keyup", (e) => {
this.shiftDown = e.shiftKey;
});
}
/** /**
* Populates the graph with the specified workflow data * Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object * @param {*} graphData A serialized graph object