{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "three_d_plot.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyM+fVOqn1BZGxcX2KBPGGIE",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"id": "0DzKR2vNKAOn"
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import plotly.graph_objects as go\n",
"\n",
"\n",
"def get_data():\n",
" x = np.linspace(-1, 1, 10)\n",
" y = np.linspace(-1, 1, 10)\n",
"\n",
" X, Y = np.meshgrid(x, y)\n",
"\n",
" Z = np.array([[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],\n",
" [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,\n",
" 3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291]])\n",
" return X, Y, Z\n",
"\n",
"\n",
"# One-color arrows & arrowheads\n",
"colorscale = [\n",
" [0, \"rgb(84,48,5)\"],\n",
" [1, \"rgb(84,48,5)\"],\n",
"]\n",
"\n",
"X, Y, Z = get_data()\n",
"\n",
"data = {\n",
" key: {\n",
" \"min\": v.min(),\n",
" \"max\": v.max(),\n",
" \"mid\": (v.max() + v.min()) / 2,\n",
" \"range\": v.max() - v.min(),\n",
" }\n",
" for (key, v) in dict(x=X, y=Y, z=Z).items()\n",
"}\n",
"\n"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 517
},
"id": "eJTMQrwHKDDL",
"outputId": "8d7e0fcb-3a7c-4e80-a7e2-f16c0e04dac1"
},
"source": [
"def get_arrow(axisname=\"x\"):\n",
"\n",
" # Create arrow body\n",
" body = go.Scatter3d(\n",
" marker=dict(size=1, color=colorscale[0][1]),\n",
" line=dict(color=colorscale[0][1], width=3),\n",
" showlegend=False, # hide the legend\n",
" )\n",
"\n",
" head = go.Cone(\n",
" sizeref=0.1,\n",
" autocolorscale=None,\n",
" colorscale=colorscale,\n",
" showscale=False, # disable additional colorscale for arrowheads\n",
" hovertext=axisname,\n",
" )\n",
" for ax, direction in zip((\"x\", \"y\", \"z\"), (\"u\", \"v\", \"w\")):\n",
" if ax == axisname:\n",
" body[ax] = data[ax][\"min\"], data[ax][\"max\"]\n",
" head[ax] = [data[ax][\"max\"]]\n",
" head[direction] = [1]\n",
" else:\n",
" body[ax] = data[ax][\"mid\"], data[ax][\"mid\"]\n",
" head[ax] = [data[ax][\"mid\"]]\n",
" head[direction] = [0]\n",
"\n",
" return [body, head]\n",
"\n",
"\n",
"def add_axis_arrows(fig):\n",
" for ax in (\"x\", \"y\", \"z\"):\n",
" for item in get_arrow(ax):\n",
" fig.add_trace(item)\n",
"\n",
"\n",
"def get_annotation_for_ax(ax):\n",
" d = dict(showarrow=False, text=ax, xanchor=\"left\", font=dict(color=\"#1f1f1f\"))\n",
" for ax_ in (\"x\", \"y\", \"z\"):\n",
" if ax_ == ax:\n",
" d[ax_] = data[ax][\"max\"] - data[ax][\"range\"] * 0.05\n",
" else:\n",
" d[ax_] = data[ax_][\"mid\"]\n",
"\n",
" if ax in {\"x\", \"y\"}:\n",
" d[\"xshift\"] = 15\n",
"\n",
" return d\n",
"\n",
"\n",
"def get_axis_names():\n",
" return [get_annotation_for_ax(ax) for ax in (\"x\", \"y\", \"z\")]\n",
"\n",
"\n",
"def get_scene_axis(axisname=\"x\"):\n",
"\n",
" return dict(\n",
" title=\"\", # remove axis label (x,y,z)\n",
" showbackground=False,\n",
" visible=True,\n",
" showticklabels=False, # hide numeric values of axes\n",
" showgrid=True, # Show box around plot\n",
" gridcolor=\"grey\", # Box color\n",
" tickvals=[data[axisname][\"min\"], data[axisname][\"max\"]], # Set box limits\n",
" range=[\n",
" data[axisname][\"min\"],\n",
" data[axisname][\"max\"],\n",
" ], # Prevent extra lines around box\n",
" )\n",
"\n",
"\n",
"fig = go.Figure(\n",
" \n",
" layout=dict(\n",
" title=\"surface\",\n",
" autosize=True,\n",
" width=700,\n",
" height=500,\n",
" margin=dict(l=20, r=20, b=25, t=25),\n",
" scene=dict(\n",
" xaxis=get_scene_axis(\"x\"),\n",
" yaxis=get_scene_axis(\"y\"),\n",
" zaxis=get_scene_axis(\"z\"),\n",
" annotations=get_axis_names(),\n",
" ),\n",
" ),\n",
")\n",
"\n",
"add_axis_arrows(fig)\n",
"\n",
"N = 1000\n",
"t = np.linspace(0, 1, 100)\n",
"y = np.sin(t)\n",
"t = np.linspace(0, 10, 50)\n",
"x, y, z = np.cos(t), np.sin(t), t\n",
"t = np.zeros_like((t))+1.0\n",
"z = t\n",
"\n",
"fig.add_trace(go.Scatter3d(x=x, y=y, z=z))\n",
"fig.show()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"