surena26 commited on
Commit
badcd5b
·
verified ·
1 Parent(s): 27c40cb

Upload ComfyUI/server.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ComfyUI/server.py +655 -0
ComfyUI/server.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import asyncio
4
+ import traceback
5
+
6
+ import nodes
7
+ import folder_paths
8
+ import execution
9
+ import uuid
10
+ import urllib
11
+ import json
12
+ import glob
13
+ import struct
14
+ import ssl
15
+ from PIL import Image, ImageOps
16
+ from PIL.PngImagePlugin import PngInfo
17
+ from io import BytesIO
18
+
19
+ import aiohttp
20
+ from aiohttp import web
21
+ import logging
22
+
23
+ import mimetypes
24
+ from comfy.cli_args import args
25
+ import comfy.utils
26
+ import comfy.model_management
27
+
28
+ from app.user_manager import UserManager
29
+
30
+ class BinaryEventTypes:
31
+ PREVIEW_IMAGE = 1
32
+ UNENCODED_PREVIEW_IMAGE = 2
33
+
34
+ async def send_socket_catch_exception(function, message):
35
+ try:
36
+ await function(message)
37
+ except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
38
+ logging.warning("send error: {}".format(err))
39
+
40
+ @web.middleware
41
+ async def cache_control(request: web.Request, handler):
42
+ response: web.Response = await handler(request)
43
+ if request.path.endswith('.js') or request.path.endswith('.css'):
44
+ response.headers.setdefault('Cache-Control', 'no-cache')
45
+ return response
46
+
47
+ def create_cors_middleware(allowed_origin: str):
48
+ @web.middleware
49
+ async def cors_middleware(request: web.Request, handler):
50
+ if request.method == "OPTIONS":
51
+ # Pre-flight request. Reply successfully:
52
+ response = web.Response()
53
+ else:
54
+ response = await handler(request)
55
+
56
+ response.headers['Access-Control-Allow-Origin'] = allowed_origin
57
+ response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
58
+ response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
59
+ response.headers['Access-Control-Allow-Credentials'] = 'true'
60
+ return response
61
+
62
+ return cors_middleware
63
+
64
+ class PromptServer():
65
+ def __init__(self, loop):
66
+ PromptServer.instance = self
67
+
68
+ mimetypes.init()
69
+ mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
70
+
71
+ self.user_manager = UserManager()
72
+ self.supports = ["custom_nodes_from_web"]
73
+ self.prompt_queue = None
74
+ self.loop = loop
75
+ self.messages = asyncio.Queue()
76
+ self.number = 0
77
+
78
+ middlewares = [cache_control]
79
+ if args.enable_cors_header:
80
+ middlewares.append(create_cors_middleware(args.enable_cors_header))
81
+
82
+ max_upload_size = round(args.max_upload_size * 1024 * 1024)
83
+ self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
84
+ self.sockets = dict()
85
+ self.web_root = os.path.join(os.path.dirname(
86
+ os.path.realpath(__file__)), "web")
87
+ routes = web.RouteTableDef()
88
+ self.routes = routes
89
+ self.last_node_id = None
90
+ self.client_id = None
91
+
92
+ self.on_prompt_handlers = []
93
+
94
+ @routes.get('/ws')
95
+ async def websocket_handler(request):
96
+ ws = web.WebSocketResponse()
97
+ await ws.prepare(request)
98
+ sid = request.rel_url.query.get('clientId', '')
99
+ if sid:
100
+ # Reusing existing session, remove old
101
+ self.sockets.pop(sid, None)
102
+ else:
103
+ sid = uuid.uuid4().hex
104
+
105
+ self.sockets[sid] = ws
106
+
107
+ try:
108
+ # Send initial state to the new client
109
+ await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
110
+ # On reconnect if we are the currently executing client send the current node
111
+ if self.client_id == sid and self.last_node_id is not None:
112
+ await self.send("executing", { "node": self.last_node_id }, sid)
113
+
114
+ async for msg in ws:
115
+ if msg.type == aiohttp.WSMsgType.ERROR:
116
+ logging.warning('ws connection closed with exception %s' % ws.exception())
117
+ finally:
118
+ self.sockets.pop(sid, None)
119
+ return ws
120
+
121
+ @routes.get("/")
122
+ async def get_root(request):
123
+ return web.FileResponse(os.path.join(self.web_root, "index.html"))
124
+
125
+ @routes.get("/embeddings")
126
+ def get_embeddings(self):
127
+ embeddings = folder_paths.get_filename_list("embeddings")
128
+ return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
129
+
130
+ @routes.get("/extensions")
131
+ async def get_extensions(request):
132
+ files = glob.glob(os.path.join(
133
+ glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
134
+
135
+ extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
136
+
137
+ for name, dir in nodes.EXTENSION_WEB_DIRS.items():
138
+ files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
139
+ extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
140
+ name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
141
+
142
+ return web.json_response(extensions)
143
+
144
+ def get_dir_by_type(dir_type):
145
+ if dir_type is None:
146
+ dir_type = "input"
147
+
148
+ if dir_type == "input":
149
+ type_dir = folder_paths.get_input_directory()
150
+ elif dir_type == "temp":
151
+ type_dir = folder_paths.get_temp_directory()
152
+ elif dir_type == "output":
153
+ type_dir = folder_paths.get_output_directory()
154
+
155
+ return type_dir, dir_type
156
+
157
+ def image_upload(post, image_save_function=None):
158
+ image = post.get("image")
159
+ overwrite = post.get("overwrite")
160
+
161
+ image_upload_type = post.get("type")
162
+ upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
163
+
164
+ if image and image.file:
165
+ filename = image.filename
166
+ if not filename:
167
+ return web.Response(status=400)
168
+
169
+ subfolder = post.get("subfolder", "")
170
+ full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
171
+ filepath = os.path.abspath(os.path.join(full_output_folder, filename))
172
+
173
+ if os.path.commonpath((upload_dir, filepath)) != upload_dir:
174
+ return web.Response(status=400)
175
+
176
+ if not os.path.exists(full_output_folder):
177
+ os.makedirs(full_output_folder)
178
+
179
+ split = os.path.splitext(filename)
180
+
181
+ if overwrite is not None and (overwrite == "true" or overwrite == "1"):
182
+ pass
183
+ else:
184
+ i = 1
185
+ while os.path.exists(filepath):
186
+ filename = f"{split[0]} ({i}){split[1]}"
187
+ filepath = os.path.join(full_output_folder, filename)
188
+ i += 1
189
+
190
+ if image_save_function is not None:
191
+ image_save_function(image, post, filepath)
192
+ else:
193
+ with open(filepath, "wb") as f:
194
+ f.write(image.file.read())
195
+
196
+ return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
197
+ else:
198
+ return web.Response(status=400)
199
+
200
+ @routes.post("/upload/image")
201
+ async def upload_image(request):
202
+ post = await request.post()
203
+ return image_upload(post)
204
+
205
+
206
+ @routes.post("/upload/mask")
207
+ async def upload_mask(request):
208
+ post = await request.post()
209
+
210
+ def image_save_function(image, post, filepath):
211
+ original_ref = json.loads(post.get("original_ref"))
212
+ filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])
213
+
214
+ # validation for security: prevent accessing arbitrary path
215
+ if filename[0] == '/' or '..' in filename:
216
+ return web.Response(status=400)
217
+
218
+ if output_dir is None:
219
+ type = original_ref.get("type", "output")
220
+ output_dir = folder_paths.get_directory_by_type(type)
221
+
222
+ if output_dir is None:
223
+ return web.Response(status=400)
224
+
225
+ if original_ref.get("subfolder", "") != "":
226
+ full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
227
+ if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
228
+ return web.Response(status=403)
229
+ output_dir = full_output_dir
230
+
231
+ file = os.path.join(output_dir, filename)
232
+
233
+ if os.path.isfile(file):
234
+ with Image.open(file) as original_pil:
235
+ metadata = PngInfo()
236
+ if hasattr(original_pil,'text'):
237
+ for key in original_pil.text:
238
+ metadata.add_text(key, original_pil.text[key])
239
+ original_pil = original_pil.convert('RGBA')
240
+ mask_pil = Image.open(image.file).convert('RGBA')
241
+
242
+ # alpha copy
243
+ new_alpha = mask_pil.getchannel('A')
244
+ original_pil.putalpha(new_alpha)
245
+ original_pil.save(filepath, compress_level=4, pnginfo=metadata)
246
+
247
+ return image_upload(post, image_save_function)
248
+
249
+ @routes.get("/view")
250
+ async def view_image(request):
251
+ if "filename" in request.rel_url.query:
252
+ filename = request.rel_url.query["filename"]
253
+ filename,output_dir = folder_paths.annotated_filepath(filename)
254
+
255
+ # validation for security: prevent accessing arbitrary path
256
+ if filename[0] == '/' or '..' in filename:
257
+ return web.Response(status=400)
258
+
259
+ if output_dir is None:
260
+ type = request.rel_url.query.get("type", "output")
261
+ output_dir = folder_paths.get_directory_by_type(type)
262
+
263
+ if output_dir is None:
264
+ return web.Response(status=400)
265
+
266
+ if "subfolder" in request.rel_url.query:
267
+ full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
268
+ if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
269
+ return web.Response(status=403)
270
+ output_dir = full_output_dir
271
+
272
+ filename = os.path.basename(filename)
273
+ file = os.path.join(output_dir, filename)
274
+
275
+ if os.path.isfile(file):
276
+ if 'preview' in request.rel_url.query:
277
+ with Image.open(file) as img:
278
+ preview_info = request.rel_url.query['preview'].split(';')
279
+ image_format = preview_info[0]
280
+ if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
281
+ image_format = 'webp'
282
+
283
+ quality = 90
284
+ if preview_info[-1].isdigit():
285
+ quality = int(preview_info[-1])
286
+
287
+ buffer = BytesIO()
288
+ if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
289
+ img = img.convert("RGB")
290
+ img.save(buffer, format=image_format, quality=quality)
291
+ buffer.seek(0)
292
+
293
+ return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
294
+ headers={"Content-Disposition": f"filename=\"{filename}\""})
295
+
296
+ if 'channel' not in request.rel_url.query:
297
+ channel = 'rgba'
298
+ else:
299
+ channel = request.rel_url.query["channel"]
300
+
301
+ if channel == 'rgb':
302
+ with Image.open(file) as img:
303
+ if img.mode == "RGBA":
304
+ r, g, b, a = img.split()
305
+ new_img = Image.merge('RGB', (r, g, b))
306
+ else:
307
+ new_img = img.convert("RGB")
308
+
309
+ buffer = BytesIO()
310
+ new_img.save(buffer, format='PNG')
311
+ buffer.seek(0)
312
+
313
+ return web.Response(body=buffer.read(), content_type='image/png',
314
+ headers={"Content-Disposition": f"filename=\"{filename}\""})
315
+
316
+ elif channel == 'a':
317
+ with Image.open(file) as img:
318
+ if img.mode == "RGBA":
319
+ _, _, _, a = img.split()
320
+ else:
321
+ a = Image.new('L', img.size, 255)
322
+
323
+ # alpha img
324
+ alpha_img = Image.new('RGBA', img.size)
325
+ alpha_img.putalpha(a)
326
+ alpha_buffer = BytesIO()
327
+ alpha_img.save(alpha_buffer, format='PNG')
328
+ alpha_buffer.seek(0)
329
+
330
+ return web.Response(body=alpha_buffer.read(), content_type='image/png',
331
+ headers={"Content-Disposition": f"filename=\"{filename}\""})
332
+ else:
333
+ return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
334
+
335
+ return web.Response(status=404)
336
+
337
+ @routes.get("/view_metadata/{folder_name}")
338
+ async def view_metadata(request):
339
+ folder_name = request.match_info.get("folder_name", None)
340
+ if folder_name is None:
341
+ return web.Response(status=404)
342
+ if not "filename" in request.rel_url.query:
343
+ return web.Response(status=404)
344
+
345
+ filename = request.rel_url.query["filename"]
346
+ if not filename.endswith(".safetensors"):
347
+ return web.Response(status=404)
348
+
349
+ safetensors_path = folder_paths.get_full_path(folder_name, filename)
350
+ if safetensors_path is None:
351
+ return web.Response(status=404)
352
+ out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
353
+ if out is None:
354
+ return web.Response(status=404)
355
+ dt = json.loads(out)
356
+ if not "__metadata__" in dt:
357
+ return web.Response(status=404)
358
+ return web.json_response(dt["__metadata__"])
359
+
360
+ @routes.get("/system_stats")
361
+ async def get_queue(request):
362
+ device = comfy.model_management.get_torch_device()
363
+ device_name = comfy.model_management.get_torch_device_name(device)
364
+ vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
365
+ vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
366
+ system_stats = {
367
+ "system": {
368
+ "os": os.name,
369
+ "python_version": sys.version,
370
+ "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
371
+ },
372
+ "devices": [
373
+ {
374
+ "name": device_name,
375
+ "type": device.type,
376
+ "index": device.index,
377
+ "vram_total": vram_total,
378
+ "vram_free": vram_free,
379
+ "torch_vram_total": torch_vram_total,
380
+ "torch_vram_free": torch_vram_free,
381
+ }
382
+ ]
383
+ }
384
+ return web.json_response(system_stats)
385
+
386
+ @routes.get("/prompt")
387
+ async def get_prompt(request):
388
+ return web.json_response(self.get_queue_info())
389
+
390
+ def node_info(node_class):
391
+ obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
392
+ info = {}
393
+ info['input'] = obj_class.INPUT_TYPES()
394
+ info['output'] = obj_class.RETURN_TYPES
395
+ info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
396
+ info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
397
+ info['name'] = node_class
398
+ info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
399
+ info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
400
+ info['category'] = 'sd'
401
+ if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
402
+ info['output_node'] = True
403
+ else:
404
+ info['output_node'] = False
405
+
406
+ if hasattr(obj_class, 'CATEGORY'):
407
+ info['category'] = obj_class.CATEGORY
408
+ return info
409
+
410
+ @routes.get("/object_info")
411
+ async def get_object_info(request):
412
+ out = {}
413
+ for x in nodes.NODE_CLASS_MAPPINGS:
414
+ try:
415
+ out[x] = node_info(x)
416
+ except Exception as e:
417
+ logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
418
+ logging.error(traceback.format_exc())
419
+ return web.json_response(out)
420
+
421
+ @routes.get("/object_info/{node_class}")
422
+ async def get_object_info_node(request):
423
+ node_class = request.match_info.get("node_class", None)
424
+ out = {}
425
+ if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
426
+ out[node_class] = node_info(node_class)
427
+ return web.json_response(out)
428
+
429
+ @routes.get("/history")
430
+ async def get_history(request):
431
+ max_items = request.rel_url.query.get("max_items", None)
432
+ if max_items is not None:
433
+ max_items = int(max_items)
434
+ return web.json_response(self.prompt_queue.get_history(max_items=max_items))
435
+
436
+ @routes.get("/history/{prompt_id}")
437
+ async def get_history(request):
438
+ prompt_id = request.match_info.get("prompt_id", None)
439
+ return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
440
+
441
+ @routes.get("/queue")
442
+ async def get_queue(request):
443
+ queue_info = {}
444
+ current_queue = self.prompt_queue.get_current_queue()
445
+ queue_info['queue_running'] = current_queue[0]
446
+ queue_info['queue_pending'] = current_queue[1]
447
+ return web.json_response(queue_info)
448
+
449
+ @routes.post("/prompt")
450
+ async def post_prompt(request):
451
+ logging.info("got prompt")
452
+ resp_code = 200
453
+ out_string = ""
454
+ json_data = await request.json()
455
+ json_data = self.trigger_on_prompt(json_data)
456
+
457
+ if "number" in json_data:
458
+ number = float(json_data['number'])
459
+ else:
460
+ number = self.number
461
+ if "front" in json_data:
462
+ if json_data['front']:
463
+ number = -number
464
+
465
+ self.number += 1
466
+
467
+ if "prompt" in json_data:
468
+ prompt = json_data["prompt"]
469
+ valid = execution.validate_prompt(prompt)
470
+ extra_data = {}
471
+ if "extra_data" in json_data:
472
+ extra_data = json_data["extra_data"]
473
+
474
+ if "client_id" in json_data:
475
+ extra_data["client_id"] = json_data["client_id"]
476
+ if valid[0]:
477
+ prompt_id = str(uuid.uuid4())
478
+ outputs_to_execute = valid[2]
479
+ self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
480
+ response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
481
+ return web.json_response(response)
482
+ else:
483
+ logging.warning("invalid prompt: {}".format(valid[1]))
484
+ return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
485
+ else:
486
+ return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
487
+
488
+ @routes.post("/queue")
489
+ async def post_queue(request):
490
+ json_data = await request.json()
491
+ if "clear" in json_data:
492
+ if json_data["clear"]:
493
+ self.prompt_queue.wipe_queue()
494
+ if "delete" in json_data:
495
+ to_delete = json_data['delete']
496
+ for id_to_delete in to_delete:
497
+ delete_func = lambda a: a[1] == id_to_delete
498
+ self.prompt_queue.delete_queue_item(delete_func)
499
+
500
+ return web.Response(status=200)
501
+
502
+ @routes.post("/interrupt")
503
+ async def post_interrupt(request):
504
+ nodes.interrupt_processing()
505
+ return web.Response(status=200)
506
+
507
+ @routes.post("/free")
508
+ async def post_free(request):
509
+ json_data = await request.json()
510
+ unload_models = json_data.get("unload_models", False)
511
+ free_memory = json_data.get("free_memory", False)
512
+ if unload_models:
513
+ self.prompt_queue.set_flag("unload_models", unload_models)
514
+ if free_memory:
515
+ self.prompt_queue.set_flag("free_memory", free_memory)
516
+ return web.Response(status=200)
517
+
518
+ @routes.post("/history")
519
+ async def post_history(request):
520
+ json_data = await request.json()
521
+ if "clear" in json_data:
522
+ if json_data["clear"]:
523
+ self.prompt_queue.wipe_history()
524
+ if "delete" in json_data:
525
+ to_delete = json_data['delete']
526
+ for id_to_delete in to_delete:
527
+ self.prompt_queue.delete_history_item(id_to_delete)
528
+
529
+ return web.Response(status=200)
530
+
531
+ def add_routes(self):
532
+ self.user_manager.add_routes(self.routes)
533
+ self.app.add_routes(self.routes)
534
+
535
+ for name, dir in nodes.EXTENSION_WEB_DIRS.items():
536
+ self.app.add_routes([
537
+ web.static('/extensions/' + urllib.parse.quote(name), dir),
538
+ ])
539
+
540
+ self.app.add_routes([
541
+ web.static('/', self.web_root),
542
+ ])
543
+
544
+ def get_queue_info(self):
545
+ prompt_info = {}
546
+ exec_info = {}
547
+ exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
548
+ prompt_info['exec_info'] = exec_info
549
+ return prompt_info
550
+
551
+ async def send(self, event, data, sid=None):
552
+ if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
553
+ await self.send_image(data, sid=sid)
554
+ elif isinstance(data, (bytes, bytearray)):
555
+ await self.send_bytes(event, data, sid)
556
+ else:
557
+ await self.send_json(event, data, sid)
558
+
559
+ def encode_bytes(self, event, data):
560
+ if not isinstance(event, int):
561
+ raise RuntimeError(f"Binary event types must be integers, got {event}")
562
+
563
+ packed = struct.pack(">I", event)
564
+ message = bytearray(packed)
565
+ message.extend(data)
566
+ return message
567
+
568
+ async def send_image(self, image_data, sid=None):
569
+ image_type = image_data[0]
570
+ image = image_data[1]
571
+ max_size = image_data[2]
572
+ if max_size is not None:
573
+ if hasattr(Image, 'Resampling'):
574
+ resampling = Image.Resampling.BILINEAR
575
+ else:
576
+ resampling = Image.ANTIALIAS
577
+
578
+ image = ImageOps.contain(image, (max_size, max_size), resampling)
579
+ type_num = 1
580
+ if image_type == "JPEG":
581
+ type_num = 1
582
+ elif image_type == "PNG":
583
+ type_num = 2
584
+
585
+ bytesIO = BytesIO()
586
+ header = struct.pack(">I", type_num)
587
+ bytesIO.write(header)
588
+ image.save(bytesIO, format=image_type, quality=95, compress_level=1)
589
+ preview_bytes = bytesIO.getvalue()
590
+ await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
591
+
592
+ async def send_bytes(self, event, data, sid=None):
593
+ message = self.encode_bytes(event, data)
594
+
595
+ if sid is None:
596
+ sockets = list(self.sockets.values())
597
+ for ws in sockets:
598
+ await send_socket_catch_exception(ws.send_bytes, message)
599
+ elif sid in self.sockets:
600
+ await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
601
+
602
+ async def send_json(self, event, data, sid=None):
603
+ message = {"type": event, "data": data}
604
+
605
+ if sid is None:
606
+ sockets = list(self.sockets.values())
607
+ for ws in sockets:
608
+ await send_socket_catch_exception(ws.send_json, message)
609
+ elif sid in self.sockets:
610
+ await send_socket_catch_exception(self.sockets[sid].send_json, message)
611
+
612
+ def send_sync(self, event, data, sid=None):
613
+ self.loop.call_soon_threadsafe(
614
+ self.messages.put_nowait, (event, data, sid))
615
+
616
+ def queue_updated(self):
617
+ self.send_sync("status", { "status": self.get_queue_info() })
618
+
619
+ async def publish_loop(self):
620
+ while True:
621
+ msg = await self.messages.get()
622
+ await self.send(*msg)
623
+
624
+ async def start(self, address, port, verbose=True, call_on_start=None):
625
+ runner = web.AppRunner(self.app, access_log=None)
626
+ await runner.setup()
627
+ ssl_ctx = None
628
+ scheme = "http"
629
+ if args.tls_keyfile and args.tls_certfile:
630
+ ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
631
+ ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
632
+ keyfile=args.tls_keyfile)
633
+ scheme = "https"
634
+
635
+ site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
636
+ await site.start()
637
+
638
+ if verbose:
639
+ logging.info("Starting server\n")
640
+ logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
641
+ if call_on_start is not None:
642
+ call_on_start(scheme, address, port)
643
+
644
+ def add_on_prompt_handler(self, handler):
645
+ self.on_prompt_handlers.append(handler)
646
+
647
+ def trigger_on_prompt(self, json_data):
648
+ for handler in self.on_prompt_handlers:
649
+ try:
650
+ json_data = handler(json_data)
651
+ except Exception as e:
652
+ logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
653
+ logging.warning(traceback.format_exc())
654
+
655
+ return json_data