surena26 committed · Commit 4a4acd7 · verified · 1 Parent(s): 7595c28

Upload ComfyUI/execution.py with huggingface_hub

Files changed (1)
ComfyUI/execution.py +835 -0
ComfyUI/execution.py ADDED
@@ -0,0 +1,835 @@
import sys
import copy
import logging
import threading
import heapq
import traceback
import inspect
from typing import List, Literal, NamedTuple, Optional

import torch
import nodes

import comfy.model_management

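# Gathers the concrete argument values for one node: linked inputs are
# resolved from the `outputs` cache via their (node_id, slot_index) pair,
# literal widget values are wrapped in single-element lists, and hidden
# inputs (PROMPT, EXTRA_PNGINFO, UNIQUE_ID) are injected from context.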
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                input_data_all[x] = (None,)
                continue
            obj = outputs[input_unique_id][output_index]
            input_data_all[x] = obj
        else:
            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
                input_data_all[x] = [input_data]

    if "hidden" in valid_inputs:
        h = valid_inputs["hidden"]
        for x in h:
            if h[x] == "PROMPT":
                input_data_all[x] = [prompt]
            if h[x] == "EXTRA_PNGINFO":
                input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
            if h[x] == "UNIQUE_ID":
                input_data_all[x] = [unique_id]
    return input_data_all

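# Calls `func` on `obj` once per list slot. If the node declares
# INPUT_IS_LIST it receives the full input lists in a single call;
# otherwise the call is repeated max-length times, repeating the last
# element of any shorter input list (see slice_dict below).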
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    # check if node wants the lists
    input_is_list = False
    if hasattr(obj, "INPUT_IS_LIST"):
        input_is_list = obj.INPUT_IS_LIST

    if len(input_data_all) == 0:
        max_len_input = 0
    else:
        max_len_input = max([len(x) for x in input_data_all.values()])

    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):
        d_new = dict()
        for k, v in d.items():
            d_new[k] = v[i if len(v) > i else -1]
        return d_new

    results = []
    if input_is_list:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)(**input_data_all))
    elif max_len_input == 0:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)())
    else:
        for i in range(max_len_input):
            if allow_interrupt:
                nodes.before_node_execution()
            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
    return results

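# Runs a node's FUNCTION via map_node_over_list, then splits the per-call
# return values into data outputs and 'ui' payloads, merging both across
# calls (respecting OUTPUT_IS_LIST for concatenation).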
def get_output_data(obj, input_data_all):

    results = []
    uis = []
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

    for r in return_values:
        if isinstance(r, dict):
            if 'ui' in r:
                uis.append(r['ui'])
            if 'result' in r:
                results.append(r['result'])
        else:
            results.append(r)

    output = []
    if len(results) > 0:
        # check which outputs need concatenating
        output_is_list = [False] * len(results[0])
        if hasattr(obj, "OUTPUT_IS_LIST"):
            output_is_list = obj.OUTPUT_IS_LIST

        # merge node execution results
        for i, is_list in zip(range(len(results[0])), output_is_list):
            if is_list:
                output.append([x for o in results for x in o[i]])
            else:
                output.append([o[i] for o in results])

    ui = dict()
    if len(uis) > 0:
        ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
    return output, ui

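# Makes a value safe to embed in an error report: primitives pass through,
# everything else (tensors, models, ...) is reduced to its str() form.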
def format_value(x):
    if x is None:
        return None
    elif isinstance(x, (int, float, bool, str)):
        return x
    else:
        return str(x)

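# Depth-first execution of one node: recurses into unexecuted upstream
# nodes first, reuses a cached node instance from `object_storage`, stores
# results in `outputs`/`outputs_ui`, and reports progress to the client.
# Returns a (success, error_details, exception) triple.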
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return (True, None, None)

    for x in inputs:
        input_data = inputs[x]

        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                if result[0] is not True:
                    # Another node failed further upstream
                    return result

    input_data_all = None
    try:
        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = unique_id
            server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)

        obj = object_storage.get((unique_id, class_type), None)
        if obj is None:
            obj = class_def()
            object_storage[(unique_id, class_type)] = obj

        output_data, output_ui = get_output_data(obj, input_data_all)
        outputs[unique_id] = output_data
        if len(output_ui) > 0:
            outputs_ui[unique_id] = output_ui
            if server.client_id is not None:
                server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
    except comfy.model_management.InterruptProcessingException as iex:
        logging.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": unique_id,
        }

        return (False, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        input_data_formatted = {}
        if input_data_all is not None:
            for name, input_list in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in input_list]

        output_data_formatted = {}
        for node_id, node_outputs in outputs.items():
            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]

        logging.error(f"!!! Exception during processing !!! {ex}")
        logging.error(traceback.format_exc())

        error_details = {
            "node_id": unique_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return (False, error_details, ex)

    executed.add(unique_id)

    return (True, None, None)

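# Returns the list of node ids that would have to run to produce
# `current_item`, walking unresolved upstream links. `memo` caches
# subtree results so shared ancestors are only counted once.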
def recursive_will_execute(prompt, outputs, current_item, memo=None):
    if memo is None:
        # fresh memo per top-level call so cached subtree results
        # don't leak across runs
        memo = {}
    unique_id = current_item

    if unique_id in memo:
        return memo[unique_id]

    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)

    memo[unique_id] = will_execute + [unique_id]
    return memo[unique_id]

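# Cache invalidation: decides whether a node's cached output is stale by
# comparing IS_CHANGED results and inputs against the previous prompt,
# recursing upstream. Stale entries are popped from `outputs`.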
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    is_changed_old = ''
    is_changed = ''
    to_delete = False
    if hasattr(class_def, 'IS_CHANGED'):
        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
            is_changed_old = old_prompt[unique_id]['is_changed']
        if 'is_changed' not in prompt[unique_id]:
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
            if input_data_all is not None:
                try:
                    #is_changed = class_def.IS_CHANGED(**input_data_all)
                    is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                    prompt[unique_id]['is_changed'] = is_changed
                except Exception:
                    to_delete = True
        else:
            is_changed = prompt[unique_id]['is_changed']

    if unique_id not in outputs:
        return True

    if not to_delete:
        if is_changed != is_changed_old:
            to_delete = True
        elif unique_id not in old_prompt:
            to_delete = True
        elif inputs == old_prompt[unique_id]['inputs']:
            for x in inputs:
                input_data = inputs[x]

                if isinstance(input_data, list):
                    input_unique_id = input_data[0]
                    output_index = input_data[1]
                    if input_unique_id in outputs:
                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                    else:
                        to_delete = True
                    if to_delete:
                        break
        else:
            to_delete = True

    if to_delete:
        d = outputs.pop(unique_id)
        del d
    return to_delete

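# Drives a whole prompt: prunes stale caches, orders the requested output
# nodes so the one needing the fewest new executions runs first, executes
# them, and relays status/error events to the connected client.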
class PromptExecutor:
    def __init__(self, server):
        self.server = server
        self.reset()

    def reset(self):
        self.outputs = {}
        self.object_storage = {}
        self.outputs_ui = {}
        self.status_messages = []
        self.success = True
        self.old_prompt = {}

    def add_message(self, event, data, broadcast: bool):
        self.status_messages.append((event, data))
        if self.server.client_id is not None or broadcast:
            self.server.send_sync(event, data, self.server.client_id)

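    # Sends "execution_interrupted" (broadcast) or "execution_error" to the
    # frontend, then drops cached outputs that were produced after the
    # failure point so they are recomputed on the next run.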
    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.add_message("execution_interrupted", mes, broadcast=True)
        else:
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),

                "exception_message": error["exception_message"],
                "exception_type": error["exception_type"],
                "traceback": error["traceback"],
                "current_inputs": error["current_inputs"],
                "current_outputs": error["current_outputs"],
            }
            self.add_message("execution_error", mes, broadcast=False)

        # Next, remove the subsequent outputs since they will not be executed
        to_delete = []
        for o in self.outputs:
            if (o not in current_outputs) and (o not in executed):
                to_delete += [o]
                if o in self.old_prompt:
                    d = self.old_prompt.pop(o)
                    del d
        for o in to_delete:
            d = self.outputs.pop(o)
            del d

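    # Main entry point for one queued prompt. Runs under
    # torch.inference_mode() so node code cannot accidentally build
    # autograd graphs.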
    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
        nodes.interrupt_processing(False)

        if "client_id" in extra_data:
            self.server.client_id = extra_data["client_id"]
        else:
            self.server.client_id = None

        self.status_messages = []
        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)

        with torch.inference_mode():
            #delete cached outputs if nodes don't exist for them
            to_delete = []
            for o in self.outputs:
                if o not in prompt:
                    to_delete += [o]
            for o in to_delete:
                d = self.outputs.pop(o)
                del d
            to_delete = []
            for o in self.object_storage:
                if o[0] not in prompt:
                    to_delete += [o]
                else:
                    p = prompt[o[0]]
                    if o[1] != p['class_type']:
                        to_delete += [o]
            for o in to_delete:
                d = self.object_storage.pop(o)
                del d

            for x in prompt:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
            for x in list(self.outputs_ui.keys()):
                if x not in current_outputs:
                    d = self.outputs_ui.pop(x)
                    del d

            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
            self.add_message("execution_cached",
                             { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                             broadcast=False)
            executed = set()
            output_node_id = None
            to_execute = []

            for node_id in list(execute_outputs):
                to_execute += [(0, node_id)]

            while len(to_execute) > 0:
                #always execute the output that depends on the least amount of unexecuted nodes first
                memo = {}
                to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
                output_node_id = to_execute.pop(0)[-1]

                # This call shouldn't raise anything if there's an error deep in
                # the actual SD code, instead it will report the node where the
                # error was raised
                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                if self.success is not True:
                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                    break

            for x in executed:
                self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None
            if comfy.model_management.DISABLE_SMART_MEMORY:
                comfy.model_management.unload_all_models()


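# Validates a single node's inputs against its INPUT_TYPES: presence of
# required inputs, link arity and RETURN_TYPES compatibility, INT/FLOAT/
# STRING coercion, min/max bounds, combo-list membership, and any custom
# VALIDATE_INPUTS hook. Results are memoized in `validated`, keyed by node
# id, as (valid, errors, unique_id) tuples.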
def validate_inputs(prompt, item, validated):
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]

    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    required_inputs = class_inputs['required']

    errors = []
    valid = True

    validate_function_inputs = []
    if hasattr(obj_class, "VALIDATE_INPUTS"):
        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args

    for x in required_inputs:
        if x not in inputs:
            error = {
                "type": "required_input_missing",
                "message": "Required input is missing",
                "details": f"{x}",
                "extra_info": {
                    "input_name": x
                }
            }
            errors.append(error)
            continue

        val = inputs[x]
        info = required_inputs[x]
        type_input = info[0]
        if isinstance(val, list):
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val
                    }
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            if r[val[1]] != type_input:
                received_type = r[val[1]]
                details = f"{x}, {received_type} != {type_input}"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val
                    }
                }
                errors.append(error)
                continue
            try:
                r = validate_inputs(prompt, o_id, validated)
                if r[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [{
                    "type": "exception_during_inner_validation",
                    "message": "Exception when validating inner node",
                    "details": str(ex),
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "exception_message": str(ex),
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                        "linked_node": val
                    }
                }]
                validated[o_id] = (False, reasons, o_id)
                continue
        else:
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex)
                    }
                }
                errors.append(error)
                continue

            if len(info) > 1:
                if "min" in info[1] and val < info[1]["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue
                if "max" in info[1] and val > info[1]["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

            if x not in validate_function_inputs:
                if isinstance(type_input, list):
                    if val not in type_input:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists, e.g. when they're
                        # lots of scanned model filepaths
                        if len(type_input) > 20:
                            list_info = f"(list of length {len(type_input)})"
                            input_config = None
                        else:
                            list_info = str(type_input)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: '{val}' not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
                        continue

    if len(validate_function_inputs) > 0:
        input_data_all = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {}
        for x in input_data_all:
            if x in validate_function_inputs:
                input_filtered[x] = input_data_all[x]

        #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
        ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
        for x in input_filtered:
            for i, r in enumerate(ret):
                if r is not True:
                    details = f"{x}"
                    if r is not False:
                        details += f" - {str(r)}"

                    error = {
                        "type": "custom_validation_failed",
                        "message": "Custom validation failed for node",
                        "details": details,
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

    if len(errors) > 0 or valid is not True:
        ret = (False, errors, unique_id)
    else:
        ret = (True, [], unique_id)

    validated[unique_id] = ret
    return ret

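# Returns the fully qualified "module.QualName" of a class, used to label
# exception types in error reports.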
def full_type_name(klass):
    module = klass.__module__
    if module == 'builtins':
        return klass.__qualname__
    return module + '.' + klass.__qualname__

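# Validates every output node in a prompt. Returns a 4-tuple of
# (valid, error_or_None, good_output_ids, node_errors). A prompt with at
# least one valid output node is still runnable; failing outputs are
# logged and ignored.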
def validate_prompt(prompt):
    outputs = set()
    for x in prompt:
        class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
            outputs.add(x)

    if len(outputs) == 0:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {}
        }
        return (False, error, [], [])

    good_outputs = set()
    errors = []
    node_errors = {}
    validated = {}
    for o in outputs:
        valid = False
        reasons = []
        try:
            m = validate_inputs(prompt, o, validated)
            valid = m[0]
            reasons = m[1]
        except Exception as ex:
            typ, _, tb = sys.exc_info()
            valid = False
            exception_type = full_type_name(typ)
            reasons = [{
                "type": "exception_during_validation",
                "message": "Exception when validating node",
                "details": str(ex),
                "extra_info": {
                    "exception_type": exception_type,
                    "traceback": traceback.format_tb(tb)
                }
            }]
            validated[o] = (False, reasons, o)

        if valid is True:
            good_outputs.add(o)
        else:
            logging.error(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
                logging.error("* (prompt):")
                for reason in reasons:
                    logging.error(f"  - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            for node_id, result in validated.items():
                valid = result[0]
                reasons = result[1]
                # If a node upstream has errors, the nodes downstream will also
                # be reported as invalid, but there will be no errors attached.
                # So don't return those nodes as having errors in the response.
                if valid is not True and len(reasons) > 0:
                    if node_id not in node_errors:
                        class_type = prompt[node_id]['class_type']
                        node_errors[node_id] = {
                            "errors": reasons,
                            "dependent_outputs": [],
                            "class_type": class_type
                        }
                        logging.error(f"* {class_type} {node_id}:")
                        for reason in reasons:
                            logging.error(f"  - {reason['message']}: {reason['details']}")
                    node_errors[node_id]["dependent_outputs"].append(o)
            logging.error("Output will be ignored")

    if len(good_outputs) == 0:
        errors_list = []
        for o, node_reasons in errors:
            for error in node_reasons:
                errors_list.append(f"{error['message']}: {error['details']}")
        errors_list = "\n".join(errors_list)

        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": errors_list,
            "extra_info": {}
        }

        return (False, error, list(good_outputs), node_errors)

    return (True, None, list(good_outputs), node_errors)

MAXIMUM_HISTORY_SIZE = 10000

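# Thread-safe priority queue of pending prompts plus execution history.
# A single RLock guards all state; the `not_empty` condition lets worker
# threads block in get() until a prompt arrives or a flag is set.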
class PromptQueue:
    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0
        self.queue = []
        self.currently_running = {}
        self.history = {}
        self.flags = {}
        server.prompt_queue = self

    def put(self, item):
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self, timeout=None):
        with self.not_empty:
            while len(self.queue) == 0:
                self.not_empty.wait(timeout=timeout)
                if timeout is not None and len(self.queue) == 0:
                    return None
            item = heapq.heappop(self.queue)
            i = self.task_counter
            self.currently_running[i] = copy.deepcopy(item)
            self.task_counter += 1
            self.server.queue_updated()
            return (item, i)

    class ExecutionStatus(NamedTuple):
        status_str: Literal['success', 'error']
        completed: bool
        messages: List[str]

    def task_done(self, item_id, outputs,
                  status: Optional['PromptQueue.ExecutionStatus']):
        with self.mutex:
            prompt = self.currently_running.pop(item_id)
            if len(self.history) > MAXIMUM_HISTORY_SIZE:
                self.history.pop(next(iter(self.history)))

            status_dict: Optional[dict] = None
            if status is not None:
                status_dict = copy.deepcopy(status._asdict())

            self.history[prompt[1]] = {
                "prompt": prompt,
                "outputs": copy.deepcopy(outputs),
                'status': status_dict,
            }
            self.server.queue_updated()

    def get_current_queue(self):
        with self.mutex:
            out = []
            for x in self.currently_running.values():
                out += [x]
            return (out, copy.deepcopy(self.queue))

    def get_tasks_remaining(self):
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        with self.mutex:
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        with self.mutex:
            for x in range(len(self.queue)):
                if function(self.queue[x]):
                    if len(self.queue) == 1:
                        self.wipe_queue()
                    else:
                        self.queue.pop(x)
                        heapq.heapify(self.queue)
                    self.server.queue_updated()
                    return True
            return False

    def get_history(self, prompt_id=None, max_items=None, offset=-1):
        with self.mutex:
            if prompt_id is None:
                out = {}
                i = 0
                if offset < 0 and max_items is not None:
                    offset = len(self.history) - max_items
                for k in self.history:
                    if i >= offset:
                        out[k] = self.history[k]
                        if max_items is not None and len(out) >= max_items:
                            break
                    i += 1
                return out
            elif prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            else:
                return {}

    def wipe_history(self):
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete):
        with self.mutex:
            self.history.pop(id_to_delete, None)

    def set_flag(self, name, data):
        with self.mutex:
            self.flags[name] = data
            self.not_empty.notify()

    def get_flags(self, reset=True):
        with self.mutex:
            if reset:
                ret = self.flags
                self.flags = {}
                return ret
            else:
                return self.flags.copy()