fabio-deep commited on
Commit
e86be9f
·
1 Parent(s): 566b916
Files changed (5) hide show
  1. .gitignore +1 -0
  2. README.md +18 -2
  3. app.py +58 -59
  4. pgm/layers.py +4 -37
  5. requirements.txt +1 -1
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  .vscode
 
2
  __pycache__
3
  *.pyc
 
1
  .vscode
2
+ .gradio
3
  __pycache__
4
  *.pyc
README.md CHANGED
@@ -4,11 +4,27 @@ emoji: 🌖
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.27.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  duplicated_from: fabio-deep/counterfactuals
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.35.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  duplicated_from: fabio-deep/counterfactuals
12
  ---
13
 
14
+ Code for the **ICML 2023** paper:
15
+
16
+ [**High Fidelity Image Counterfactuals with Probabilistic Causal Models**](https://arxiv.org/abs/2306.15764)
17
+
18
+ Fabio De Sousa Ribeiro<sup>1</sup>, Tian Xia<sup>1</sup>, Miguel Monteiro<sup>1</sup>, Nick Pawlowski<sup>2</sup>, Ben Glocker<sup>1</sup>\
19
+ <sup>1</sup>Imperial College London, <sup>2</sup>Microsoft Research Cambridge, UK
20
+
21
+ ```
22
+ @misc{ribeiro2023high,
23
+ title={High Fidelity Image Counterfactuals with Probabilistic Causal Models},
24
+ author={Fabio De Sousa Ribeiro and Tian Xia and Miguel Monteiro and Nick Pawlowski and Ben Glocker},
25
+ year={2023},
26
+ eprint={2306.15764},
27
+ archivePrefix={arXiv},
28
+ primaryClass={cs.LG}
29
+ }
30
+ ```
app.py CHANGED
@@ -401,36 +401,47 @@ def infer_chest_cf(*args):
401
  return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
402
 
403
 
404
- with gr.Blocks(theme=gr.themes.Default()) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  with gr.Tabs():
406
  with gr.TabItem("Morpho-MNIST") as mnist_tab:
407
  mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)
408
-
409
- with gr.Row().style(equal_height=True):
410
  idx = gr.Number(value=0, visible=False)
411
  with gr.Column(scale=1, min_width=200):
412
- x = gr.Image(label="Observation", interactive=False).style(
413
- height=HEIGHT
414
- )
415
  with gr.Column(scale=1, min_width=200):
416
- cf_x = gr.Image(label="Counterfactual", interactive=False).style(
417
- height=HEIGHT
418
- )
419
  with gr.Column(scale=1, min_width=200):
420
- cf_x_std = gr.Image(
421
- label="Counterfactual Uncertainty", interactive=False
422
- ).style(height=HEIGHT)
423
  with gr.Column(scale=1, min_width=200):
424
- effect = gr.Image(
425
- label="Direct Causal Effect", interactive=False
426
- ).style(height=HEIGHT)
427
- with gr.Row().style(equal_height=True):
428
  with gr.Column(scale=1.75):
429
  gr.Markdown(
430
  "**Intervention**"
431
- + 20 * "&ensp;"
432
- + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
433
- + "&ensp; | &ensp; Hint: try 90% zoom"
434
  )
435
  with gr.Column():
436
  do_y = gr.Checkbox(label="do(digit)", value=False)
@@ -460,38 +471,34 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
460
  submit = gr.Button("Submit", variant="primary")
461
  with gr.Column(scale=1):
462
  gr.Markdown("### &nbsp;")
463
- causal_graph = gr.Image(
464
- label="Causal Graph", interactive=False
465
- ).style(height=300)
466
 
467
  with gr.TabItem("Brain MRI") as brain_tab:
468
  brain_id = gr.Textbox(value=brain_tab.label, visible=False)
469
 
470
- with gr.Row().style(equal_height=True):
471
  idx_brain = gr.Number(value=0, visible=False)
472
  with gr.Column(scale=1, min_width=200):
473
- x_brain = gr.Image(label="Observation", interactive=False).style(
474
- height=HEIGHT
475
- )
476
  with gr.Column(scale=1, min_width=200):
477
- cf_x_brain = gr.Image(
478
- label="Counterfactual", interactive=False
479
- ).style(height=HEIGHT)
480
  with gr.Column(scale=1, min_width=200):
481
  cf_x_std_brain = gr.Image(
482
- label="Counterfactual Uncertainty", interactive=False
483
- ).style(height=HEIGHT)
484
  with gr.Column(scale=1, min_width=200):
485
  effect_brain = gr.Image(
486
- label="Direct Causal Effect", interactive=False
487
- ).style(height=HEIGHT)
488
  with gr.Row():
489
  with gr.Column(scale=2.55):
490
  gr.Markdown(
491
  "**Intervention**"
492
- + 20 * "&ensp;"
493
- + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
494
- + "&ensp; | &ensp; Hint: try 90% zoom"
495
  )
496
  with gr.Row():
497
  with gr.Column(min_width=200):
@@ -543,41 +550,33 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
543
  submit_brain = gr.Button("Submit", variant="primary")
544
  with gr.Column(scale=1):
545
  # gr.Markdown("### &nbsp;")
546
- causal_graph_brain = gr.Image(
547
- label="Causal Graph", interactive=False
548
- ).style(height=340)
549
 
550
  with gr.TabItem("Chest X-ray") as chest_tab:
551
  chest_id = gr.Textbox(value=chest_tab.label, visible=False)
552
 
553
- with gr.Row().style(equal_height=True):
554
  idx_chest = gr.Number(value=0, visible=False)
555
  with gr.Column(scale=1, min_width=200):
556
- x_chest = gr.Image(label="Observation", interactive=False).style(
557
- height=HEIGHT
558
- )
559
  with gr.Column(scale=1, min_width=200):
560
- cf_x_chest = gr.Image(
561
- label="Counterfactual", interactive=False
562
- ).style(height=HEIGHT)
563
  with gr.Column(scale=1, min_width=200):
564
- cf_x_std_chest = gr.Image(
565
- label="Counterfactual Uncertainty", interactive=False
566
- ).style(height=HEIGHT)
567
  with gr.Column(scale=1, min_width=200):
568
- effect_chest = gr.Image(
569
- label="Direct Causal Effect", interactive=False
570
- ).style(height=HEIGHT)
571
 
572
  with gr.Row():
573
  with gr.Column(scale=2.55):
574
  gr.Markdown(
575
  "**Intervention**"
576
- + 20 * "&ensp;"
577
- + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
578
- + "&ensp; | &ensp; Hint: try 90% zoom"
579
  )
580
- with gr.Row().style(equal_height=True):
581
  with gr.Column(min_width=200):
582
  do_f_chest = gr.Checkbox(label="do(disease)", value=False)
583
  f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
@@ -603,9 +602,9 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
603
  submit_chest = gr.Button("Submit", variant="primary")
604
  with gr.Column(scale=1):
605
  # gr.Markdown("### &nbsp;")
606
- causal_graph_chest = gr.Image(
607
- label="Causal Graph", interactive=False
608
- ).style(height=345)
609
 
610
  # morphomnist
611
  do = [do_t, do_i, do_y]
 
401
  return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
402
 
403
 
404
+ js_func = """
405
+ function refresh() {
406
+ const url = new URL(window.location);
407
+ if (url.searchParams.get('__theme') !== 'light') {
408
+ url.searchParams.set('__theme', 'light');
409
+ window.location.href = url.href;
410
+ }
411
+ }
412
+ """
413
+
414
+ with gr.Blocks(
415
+ # theme=gr.themes.Default(),
416
+ theme="shivi/calm_seafoam",
417
+ js=js_func
418
+ ) as demo:
419
+ img_cfg = dict(
420
+ interactive=False,
421
+ height=HEIGHT,
422
+ show_download_button=False,
423
+ show_fullscreen_button=False
424
+ )
425
  with gr.Tabs():
426
  with gr.TabItem("Morpho-MNIST") as mnist_tab:
427
  mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)
428
+ with gr.Row(equal_height=True):
 
429
  idx = gr.Number(value=0, visible=False)
430
  with gr.Column(scale=1, min_width=200):
431
+ x = gr.Image(label="Observation", **img_cfg)
 
 
432
  with gr.Column(scale=1, min_width=200):
433
+ cf_x = gr.Image(label="Counterfactual", **img_cfg)
 
 
434
  with gr.Column(scale=1, min_width=200):
435
+ cf_x_std = gr.Image(label="Uncertainty", **img_cfg)
 
 
436
  with gr.Column(scale=1, min_width=200):
437
+ effect = gr.Image(label="Causal Effect", **img_cfg)
438
+ with gr.Row(equal_height=True):
 
 
439
  with gr.Column(scale=1.75):
440
  gr.Markdown(
441
  "**Intervention**"
442
+ + 22 * "&ensp;"
443
+ + "[Paper](https://proceedings.mlr.press/v202/de-sousa-ribeiro23a.html) &ensp; | &ensp; [Code](https://github.com/biomedia-mira/causal-gen)"
444
+ # + "&ensp; | &ensp; Hint: try 90% zoom"
445
  )
446
  with gr.Column():
447
  do_y = gr.Checkbox(label="do(digit)", value=False)
 
471
  submit = gr.Button("Submit", variant="primary")
472
  with gr.Column(scale=1):
473
  gr.Markdown("### &nbsp;")
474
+ img_cfg["height"] = 300
475
+ causal_graph = gr.Image(label="Causal Graph", **img_cfg)
476
+ img_cfg["height"] = HEIGHT
477
 
478
  with gr.TabItem("Brain MRI") as brain_tab:
479
  brain_id = gr.Textbox(value=brain_tab.label, visible=False)
480
 
481
+ with gr.Row(equal_height=True):
482
  idx_brain = gr.Number(value=0, visible=False)
483
  with gr.Column(scale=1, min_width=200):
484
+ x_brain = gr.Image(label="Observation", **img_cfg)
 
 
485
  with gr.Column(scale=1, min_width=200):
486
+ cf_x_brain = gr.Image(label="Counterfactual", **img_cfg)
 
 
487
  with gr.Column(scale=1, min_width=200):
488
  cf_x_std_brain = gr.Image(
489
+ label="Uncertainty", **img_cfg
490
+ )
491
  with gr.Column(scale=1, min_width=200):
492
  effect_brain = gr.Image(
493
+ label="Causal Effect", **img_cfg
494
+ )
495
  with gr.Row():
496
  with gr.Column(scale=2.55):
497
  gr.Markdown(
498
  "**Intervention**"
499
+ + 22 * "&ensp;"
500
+ + "[Paper](https://proceedings.mlr.press/v202/de-sousa-ribeiro23a.html) &ensp; | &ensp; [Code](https://github.com/biomedia-mira/causal-gen)"
501
+ # + "&ensp; | &ensp; Hint: try 90% zoom"
502
  )
503
  with gr.Row():
504
  with gr.Column(min_width=200):
 
550
  submit_brain = gr.Button("Submit", variant="primary")
551
  with gr.Column(scale=1):
552
  # gr.Markdown("### &nbsp;")
553
+ img_cfg["height"] = 340
554
+ causal_graph_brain = gr.Image(label="Causal Graph", **img_cfg)
555
+ img_cfg["height"] = HEIGHT
556
 
557
  with gr.TabItem("Chest X-ray") as chest_tab:
558
  chest_id = gr.Textbox(value=chest_tab.label, visible=False)
559
 
560
+ with gr.Row(equal_height=True):
561
  idx_chest = gr.Number(value=0, visible=False)
562
  with gr.Column(scale=1, min_width=200):
563
+ x_chest = gr.Image(label="Observation", **img_cfg)
 
 
564
  with gr.Column(scale=1, min_width=200):
565
+ cf_x_chest = gr.Image(label="Counterfactual", **img_cfg)
 
 
566
  with gr.Column(scale=1, min_width=200):
567
+ cf_x_std_chest = gr.Image(label="Uncertainty", **img_cfg)
 
 
568
  with gr.Column(scale=1, min_width=200):
569
+ effect_chest = gr.Image(label="Causal Effect", **img_cfg)
 
 
570
 
571
  with gr.Row():
572
  with gr.Column(scale=2.55):
573
  gr.Markdown(
574
  "**Intervention**"
575
+ + 22 * "&ensp;"
576
+ + "[Paper](https://proceedings.mlr.press/v202/de-sousa-ribeiro23a.html) &ensp; | &ensp; [Code](https://github.com/biomedia-mira/causal-gen)"
577
+ # + "&ensp; | &ensp; Hint: try 90% zoom"
578
  )
579
+ with gr.Row(equal_height=True):
580
  with gr.Column(min_width=200):
581
  do_f_chest = gr.Checkbox(label="do(disease)", value=False)
582
  f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
 
602
  submit_chest = gr.Button("Submit", variant="primary")
603
  with gr.Column(scale=1):
604
  # gr.Markdown("### &nbsp;")
605
+ img_cfg["height"] = 345
606
+ causal_graph_chest = gr.Image(label="Causal Graph", **img_cfg)
607
+ img_cfg["height"] = HEIGHT
608
 
609
  # morphomnist
610
  do = [do_t, do_i, do_y]
pgm/layers.py CHANGED
@@ -91,7 +91,7 @@ class CNN(nn.Module):
91
 
92
 
93
  class ArgMaxGumbelMax(Transform):
94
- r"""ArgMax as Transform, but inv conditioned on logits"""
95
 
96
  def __init__(self, logits, event_dim=0, cache_size=0):
97
  super(ArgMaxGumbelMax, self).__init__(cache_size=cache_size)
@@ -106,9 +106,6 @@ class ArgMaxGumbelMax(Transform):
106
  return self._event_dim
107
 
108
  def __call__(self, gumbels):
109
- """
110
- Computes the forward transform
111
- """
112
  assert self.logits != None, "Logits not defined."
113
 
114
  if self._cache_size == 0:
@@ -118,20 +115,12 @@ class ArgMaxGumbelMax(Transform):
118
  return y
119
 
120
  def _call(self, gumbels):
121
- """
122
- Abstract method to compute forward transformation.
123
- """
124
  assert self.logits != None, "Logits not defined."
125
  y = gumbels + self.logits
126
- # print(f'y: {y}')
127
- # print(f'logits: {self.logits}')
128
  return y.argmax(-1, keepdim=True)
129
 
130
  @property
131
  def domain(self):
132
- """ "
133
- Domain of input(gumbel variables), Real
134
- """
135
  if self.event_dim == 0:
136
  return constraints.real
137
  return constraints.independent(constraints.real, self.event_dim)
@@ -148,18 +137,14 @@ class ArgMaxGumbelMax(Transform):
148
  def inv(self, k):
149
  """Infer the gumbels noises given k and logits."""
150
  assert self.logits != None, "Logits not defined."
151
-
152
  uniforms = torch.rand(
153
  self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
154
  )
155
  gumbels = -((-(uniforms.log())).log())
156
- # print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
157
  # (batch_size, num_classes) mask to select kth class
158
- # print(f'k : {k.size()}')
159
  mask = F.one_hot(
160
  k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
161
  )
162
- # print(f'mask: {mask.size()}, {mask.dtype}')
163
  # (batch_size, 1) select topgumbel for truncation of other classes
164
  topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
165
  mask * self.logits
@@ -173,41 +158,25 @@ class ArgMaxGumbelMax(Transform):
173
  return epsilons
174
 
175
  def log_abs_det_jacobian(self, x, y):
176
- """We use the log_abs_det_jacobian to account for the categorical prob
177
- x: Gumbels; y: argmax(x+logits)
178
- P(y) = softmax
179
- """
180
- # print(f"logits: {torch.log(F.softmax(self.logits, dim=-1)).size()}")
181
- # print(f'y: {y.size()} ')
182
- # print(f"log_abs_det_jacobian: {self._categorical.log_prob(y.squeeze(-1)).unsqueeze(-1).size()}")
183
  return -self._categorical.log_prob(y.squeeze(-1)).unsqueeze(-1)
184
 
185
 
186
  class ConditionalGumbelMax(ConditionalTransformModule):
187
- r"""Given gumbels+logits, output the OneHot Categorical"""
188
-
189
  def __init__(self, context_nn, event_dim=0, **kwargs):
190
- # The logits_nn which predict the logits given ages:
191
  super().__init__(**kwargs)
192
  self.context_nn = context_nn
193
  self.event_dim = event_dim
194
 
195
  def condition(self, context):
196
- """Given context (age), output the Categorical results"""
197
- logits = self.context_nn(
198
- context
199
- ) # The logits for calculating argmax(Gumbel + logits)
200
  return ArgMaxGumbelMax(logits)
201
 
202
  def _logits(self, context):
203
- """Return logits given context"""
204
  return self.context_nn(context)
205
 
206
  @property
207
  def domain(self):
208
- """ "
209
- Domain of input(gumbel variables), Real
210
- """
211
  if self.event_dim == 0:
212
  return constraints.real
213
  return constraints.independent(constraints.real, self.event_dim)
@@ -224,6 +193,7 @@ class ConditionalGumbelMax(ConditionalTransformModule):
224
 
225
  class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
226
  r"""Define a TransformedDistribution class for Gumbel max"""
 
227
  arg_constraints: Dict[str, constraints.Constraint] = {}
228
 
229
  def log_prob(self, value):
@@ -231,7 +201,6 @@ class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributio
231
  We do not use the log_prob() of the base Gumbel distribution, because the likelihood for
232
  each class for Gumbel Max sampling is determined by the logits.
233
  """
234
- # print("This happens")
235
  if self._validate_args:
236
  self._validate_sample(value)
237
  event_dim = len(self.event_shape)
@@ -245,7 +214,6 @@ class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributio
245
  event_dim - transform.domain.event_dim,
246
  )
247
  y = x
248
- # print(f"log_prob: {log_prob.size()}")
249
  return log_prob
250
 
251
 
@@ -253,7 +221,6 @@ class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribu
253
  def condition(self, context):
254
  base_dist = self.base_dist.condition(context)
255
  transforms = [t.condition(context) for t in self.transforms]
256
- # return TransformedDistribution(base_dist, transforms)
257
  return TransformedDistributionGumbelMax(base_dist, transforms)
258
 
259
  def clear_cache(self):
 
91
 
92
 
93
  class ArgMaxGumbelMax(Transform):
94
+ r"""ArgMax as Transform, with inverse conditioned on logits"""
95
 
96
  def __init__(self, logits, event_dim=0, cache_size=0):
97
  super(ArgMaxGumbelMax, self).__init__(cache_size=cache_size)
 
106
  return self._event_dim
107
 
108
  def __call__(self, gumbels):
 
 
 
109
  assert self.logits != None, "Logits not defined."
110
 
111
  if self._cache_size == 0:
 
115
  return y
116
 
117
  def _call(self, gumbels):
 
 
 
118
  assert self.logits != None, "Logits not defined."
119
  y = gumbels + self.logits
 
 
120
  return y.argmax(-1, keepdim=True)
121
 
122
  @property
123
  def domain(self):
 
 
 
124
  if self.event_dim == 0:
125
  return constraints.real
126
  return constraints.independent(constraints.real, self.event_dim)
 
137
  def inv(self, k):
138
  """Infer the gumbels noises given k and logits."""
139
  assert self.logits != None, "Logits not defined."
 
140
  uniforms = torch.rand(
141
  self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
142
  )
143
  gumbels = -((-(uniforms.log())).log())
 
144
  # (batch_size, num_classes) mask to select kth class
 
145
  mask = F.one_hot(
146
  k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
147
  )
 
148
  # (batch_size, 1) select topgumbel for truncation of other classes
149
  topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
150
  mask * self.logits
 
158
  return epsilons
159
 
160
  def log_abs_det_jacobian(self, x, y):
161
+ """bit hacky for now"""
 
 
 
 
 
 
162
  return -self._categorical.log_prob(y.squeeze(-1)).unsqueeze(-1)
163
 
164
 
165
  class ConditionalGumbelMax(ConditionalTransformModule):
 
 
166
  def __init__(self, context_nn, event_dim=0, **kwargs):
 
167
  super().__init__(**kwargs)
168
  self.context_nn = context_nn
169
  self.event_dim = event_dim
170
 
171
  def condition(self, context):
172
+ logits = self.context_nn(context)
 
 
 
173
  return ArgMaxGumbelMax(logits)
174
 
175
  def _logits(self, context):
 
176
  return self.context_nn(context)
177
 
178
  @property
179
  def domain(self):
 
 
 
180
  if self.event_dim == 0:
181
  return constraints.real
182
  return constraints.independent(constraints.real, self.event_dim)
 
193
 
194
  class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
195
  r"""Define a TransformedDistribution class for Gumbel max"""
196
+
197
  arg_constraints: Dict[str, constraints.Constraint] = {}
198
 
199
  def log_prob(self, value):
 
201
  We do not use the log_prob() of the base Gumbel distribution, because the likelihood for
202
  each class for Gumbel Max sampling is determined by the logits.
203
  """
 
204
  if self._validate_args:
205
  self._validate_sample(value)
206
  event_dim = len(self.event_shape)
 
214
  event_dim - transform.domain.event_dim,
215
  )
216
  y = x
 
217
  return log_prob
218
 
219
 
 
221
  def condition(self, context):
222
  base_dist = self.base_dist.condition(context)
223
  transforms = [t.condition(context) for t in self.transforms]
 
224
  return TransformedDistributionGumbelMax(base_dist, transforms)
225
 
226
  def clear_cache(self):
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio==3.27.0
2
  matplotlib==3.7.1
3
  networkx==2.8.4
4
  numpy==1.24.3
 
1
+ gradio==5.35.0
2
  matplotlib==3.7.1
3
  networkx==2.8.4
4
  numpy==1.24.3