Add existing AxesSubplot object to another Subplot

I followed a tutorial from [blog.paperspace.com](https://blog.paperspace.com/model-interpretability-and-understanding-for-pytorch-using-captum/) where I used Captum to visualize what my model learns. So far it works great, and I get visualizations that make sense.

However, I would love to add the output (plt_fig, plt_axis) of the following function to an existing subplot:

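```python
import numpy as np
from captum.attr import visualization as viz
# attributions_ig_nt, transformed_img and default_cmap come from the tutorial
plt_fig, plt_axis = viz.visualize_image_attr_multiple(
    np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    ["original_image", "heat_map"],
    ["all", "positive"],
    cmap=default_cmap,
    show_colorbar=True)
```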

Currently the output of this function is the following image:

[Image: "cat", the original image and its heat map side by side]

However, I would like to be able to add these panels to an existing figure with subplots, e.g.:

Is this possible? I have tried multiple things, but none of them has worked so far. I am aware that one solution is to change the source code; however, I am not very interested in that, since I would like to keep the source code as it is.

Is there some other way? I even tried to save the image as a PNG and then load it again, but I don't like that solution 🙂

It would also be great if I could take just the first image from the subplot and add it.
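As far as I know, Matplotlib has no supported way to move an existing Axes into a different figure. If only the first panel is needed, though, its image data can be copied instead: `Axes.images` lists the `AxesImage` objects created by `imshow`, and `get_array()`, `get_cmap()` and `norm` give back the data and its color mapping. The sketch below assumes the panel was drawn with a single `imshow` call (for the call above, the `"original_image"` panel should be `plt_axis[0]`); `copy_first_image` is a hypothetical helper, a plain Matplotlib figure stands in for the Captum output, and titles and colorbars are not copied.

```python
import matplotlib.pyplot as plt
import numpy as np


def copy_first_image(src_ax, dst_ax, **imshow_kwargs):
    """Hypothetical helper: redraw the first AxesImage of src_ax on dst_ax.

    Only the pixel data, colormap and normalization are copied; titles,
    colorbars and other artists of src_ax are left behind.
    """
    src_img = src_ax.images[0]  # first image added with imshow()
    return dst_ax.imshow(
        src_img.get_array(),
        cmap=src_img.get_cmap(),
        norm=src_img.norm,
        **imshow_kwargs,
    )


# Stand-in for the Captum figure: one axes holding a single image.
src_fig, src_ax = plt.subplots()
data = np.arange(12, dtype=float).reshape(3, 4)
src_ax.imshow(data, cmap="viridis")

# Existing grid of subplots; put the copied image in the top-left cell.
fig, ax_array = plt.subplots(2, 4)
new_img = copy_first_image(src_ax, ax_array[0, 0])
ax_array[0, 0].axis("off")
```

Because the result is an ordinary image in the target figure, it can be resized or titled like any other subplot; the trade-off is that anything Captum drew besides the image itself has to be recreated by hand.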

I do not think it is possible to do what you describe. However, if I’m reading the Captum docs right, you should be able to use `visualize_image_attr` and pass a figure and axes of your choice to draw on (its `plt_fig_axis` argument).
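Passing a figure and axes in, instead of letting the function create its own, is a common Matplotlib convention. The sketch below shows the pattern with plain Matplotlib so it runs without Captum: `draw_heat_map` is a hypothetical stand-in that, like Captum's `plt_fig_axis` argument, draws into a caller-supplied `(fig, ax)` tuple and only creates a new figure when none is given.

```python
import matplotlib.pyplot as plt
import numpy as np


def draw_heat_map(attr, plt_fig_axis=None, cmap="viridis", show_colorbar=False):
    """Hypothetical helper using the same (figure, axes) convention as
    Captum's plt_fig_axis argument: draw into the given axes if one is
    passed, otherwise create a new figure."""
    if plt_fig_axis is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = plt_fig_axis
    img = ax.imshow(attr, cmap=cmap)
    if show_colorbar:
        fig.colorbar(img, ax=ax)  # adds a colorbar axes next to ax
    ax.axis("off")
    return fig, ax


# The caller owns the figure, so it decides which cells get filled.
fig, ax_array = plt.subplots(2, 4, figsize=(12, 6))
rng = np.random.default_rng(0)
results = [
    draw_heat_map(rng.random((8, 8)), plt_fig_axis=(fig, ax), show_colorbar=True)
    for ax in ax_array[0, :2]
]
```

The remaining six cells of `ax_array` stay free for other plots, which is the same layout the Captum-based answer relies on.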

Thanks for your reply @rcomer. How would you do that? Could you give an example?

I have not tested this, but it looks like you should be able to translate your example into something like:

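```python
import matplotlib.pyplot as plt
import numpy as np
from captum.attr import visualization as viz

# attributions_ig_nt, transformed_img and default_cmap come from the tutorial
attr = np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1, 2, 0))
img = np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0))

fig, ax_array = plt.subplots(2, 4)

# Draw each visualization into its own axes of the existing grid
for method, sign, ax in zip(["original_image", "heat_map"],
                            ["all", "positive"], ax_array.flat):
    viz.visualize_image_attr(
        attr, img, method, sign, (fig, ax), cmap=default_cmap,
        show_colorbar=method != "original_image")
```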

That worked! Thanks!
