
I want to plot a sequence of three colormaps in a 3D space, with a line crossing all the planes of the colormaps, as shown in the figure below.

https://i.sstatic.net/65yOib6B.png

To do that, I am using ax.plot_surface (the 3D axes method) to generate the planes, and LinearSegmentedColormap to create a colormap that transitions from transparent to a specific color.
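A quick way to see what the custom colormap does is to evaluate it at both ends of the normalized range. The low end should be fully transparent white and the high end fully opaque blue. This sketch uses the same from_list call as the code below but only checks the colormap itself. It does not show how plot_surface renders those colors.

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Same two-stop colormap as in the question: transparent white -> opaque blue
cmap = LinearSegmentedColormap.from_list('custom_blue', [(1, 1, 1, 0), (0, 0, 1, 1)])

# A small Gaussian patch, normalized to [0, 1] the same way as Z[i] / Z[i].max()
x = np.linspace(-5, 5, 5)
X, Y = np.meshgrid(x, x)
Z = np.exp(-(X**2 + Y**2) / 2.0)
rgba = cmap(Z / Z.max())  # shape (5, 5, 4): one RGBA tuple per grid point

print(cmap(0.0))  # low end: white with alpha 0
print(cmap(1.0))  # high end: blue with alpha 1
print(rgba[2, 2], rgba[0, 0])  # center (peak, opaque blue) vs. corner (transparent end)
```

The per-point RGBA array is what gets passed as facecolors, so any transparency has to come from its fourth channel.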

However, when I plot the figure, a gray grid appears on my plot. How can I remove it? Ideally, the blue shade would appear on a completely transparent plane, but a lighter color could also work.

Here is the code I used to generate the plot:

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Testing Data
sigma = 1.0
mu = np.linspace(0,2, 10)

x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)

Z = []
for m in mu:
    Z.append(np.exp(-((X - m)**2 + (Y - m)**2) / (2 * sigma**2)))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for i in [0, 5, -1]:
    cmap = LinearSegmentedColormap.from_list('custom_blue', [(1, 1, 1, 0), (0, 0, 1, 1)])
    wmap = cmap(Z[i]/Z[i].max())
    ax.plot_surface(mu[i] * np.ones(X.shape), X, Y,facecolors=wmap, alpha=1, antialiased=True, edgecolor='none')

loc_max_x = []
loc_max_y = []
for i in range(len(mu)):
    loc_x = np.where(Z[i] == Z[i].max())[0][0]
    loc_y = np.where(Z[i] == Z[i].max())[1][0]

    loc_max_x.append(loc_x)
    loc_max_y.append(loc_y)

ax.plot(mu, x[loc_max_x], y[loc_max_y], color='r')
ax.set_box_aspect((3.4, 1, 1))

plt.savefig('3dplot.png', dpi=300)
plt.show()
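One detail in the peak-finding loop: with np.meshgrid(x, y) and its default 'xy' indexing, the row index of Z[i] follows y and the column index follows x. This means np.where(...)[0] is really a y index. The test Gaussian is centered at (m, m), and x and y are the same grid, so this does not change the plot here. It would matter for an off-diagonal peak. Below is a minimal sketch that makes the mapping explicit with np.argmax and np.unravel_index. The helper name peak_location is hypothetical.

```python
import numpy as np

sigma = 1.0
mu = np.linspace(0, 2, 10)
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)  # X[r, c] == x[c] and Y[r, c] == y[r]

def peak_location(Z, x, y):
    """Hypothetical helper: (x, y) coordinates of the maximum of Z sampled on meshgrid(x, y)."""
    row, col = np.unravel_index(np.argmax(Z), Z.shape)
    # With 'xy' indexing, columns follow x and rows follow y
    return x[col], y[row]

peaks = [peak_location(np.exp(-((X - m)**2 + (Y - m)**2) / (2 * sigma**2)), x, y)
         for m in mu]
peak_x = np.array([p[0] for p in peaks])
peak_y = np.array([p[1] for p in peaks])
```

The line could then be drawn with ax.plot(mu, peak_x, peak_y, color='r'), since the planes are plotted with X on the second axis and Y on the third.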
  • I'm not sure that can be removed. Matplotlib isn't exactly the best library for 3D plotting. – jared, commented Oct 10, 2024 at 1:08
  • I agree with @jared, see my answer below. – commented Oct 10, 2024 at 8:04

1 Answer


I think there's nothing you could have done better in matplotlib, great job!

To solve your problem, I think it is better to switch libraries and use Plotly instead.

Please see my code:

import plotly.graph_objects as go
import numpy as np


# Testing Data
sigma = 1.0
mu = np.linspace(0, 2, 10)

x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)

Z = []
for m in mu:
    Z.append(np.exp(-((X - m)**2 + (Y - m)**2) / (2 * sigma**2)))

fig = go.Figure()

colorscale = [[0, 'rgba(255, 255, 255, 0)'], [1, 'rgba(0, 0, 255, 1)']]  # colorscale = transparent to blue

#plot the surfaces 
for i in [0, 5, -1]:
    fig.add_trace(go.Surface(
        x=mu[i] * np.ones(X.shape), y=X, z=Y, surfacecolor=Z[i], 
        colorscale=colorscale, cmin=0, cmax=Z[i].max(),
        showscale=False, opacity=1))

#plot the line crossing the surfaces
loc_max_x = []
loc_max_y = []
for i in range(len(mu)):
    loc_x = np.where(Z[i] == Z[i].max())[0][0]
    loc_y = np.where(Z[i] == Z[i].max())[1][0]
    loc_max_x.append(loc_x)
    loc_max_y.append(loc_y)

#add the line trace
fig.add_trace(go.Scatter3d(
    x=mu, y=x[loc_max_x], z=y[loc_max_y], 
    mode='lines', line=dict(color='red', width=5)))

fig.update_layout(scene_aspectmode='manual',
                  scene_aspectratio=dict(x=3.4, y=1, z=1),
                  scene=dict(xaxis_title='mu', yaxis_title='X', zaxis_title='Y'))

fig.show()
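In the Plotly version, cmin=0 and cmax=Z[i].max() play the role of Z[i]/Z[i].max() in the matplotlib code. Each surfacecolor value is normalized to [0, 1] and then colored by interpolating between the colorscale stops. The sketch below approximates that mapping in NumPy for the two-stop transparent-to-blue scale. It illustrates the idea and is not Plotly's rendering code. The helper two_stop_color is hypothetical, and I have not checked how Plotly's surface renderer treats the alpha channel, so the alpha column is only the intended value.

```python
import numpy as np

# Colorscale from the answer, written as RGBA with channels in [0, 1]
LOW = np.array([1.0, 1.0, 1.0, 0.0])   # rgba(255, 255, 255, 0)
HIGH = np.array([0.0, 0.0, 1.0, 1.0])  # rgba(0, 0, 255, 1)

def two_stop_color(values, cmin, cmax, low=LOW, high=HIGH):
    """Hypothetical helper: normalize with cmin/cmax, then linearly interpolate two RGBA stops."""
    t = (np.asarray(values, dtype=float) - cmin) / (cmax - cmin)
    t = np.clip(t, 0.0, 1.0)  # values outside [cmin, cmax] saturate at the end colors
    return low + np.expand_dims(t, axis=-1) * (high - low)

# One plane built the same way as in the answer
x = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, x)
m = 1.0
Z = np.exp(-((X - m)**2 + (Y - m)**2) / 2.0)

colors = two_stop_color(Z, cmin=0, cmax=Z.max())  # shape (100, 100, 4)
```

Because cmax is set per plane, each plane's peak reaches full blue regardless of its absolute height, matching the per-plane normalization in the matplotlib code.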

which results in this plot: [image]

  • That's really a great result. Apparently it's time to move to a newer library. Thank you very much! – commented Oct 11, 2024 at 14:25
