Skip to content

Commit df4827d

Browse files
committed
format and update target version
1 parent 7649051 commit df4827d

5 files changed

Lines changed: 55 additions & 50 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ build-backend = "poetry.core.masonry.api"
7272

7373
[tool.ruff]
7474
line-length = 88
75-
target-version = "py38"
75+
target-version = "py311"
7676
exclude = ["camera_derivatives.py", "**/*.ipynb"]
7777

7878
[tool.ruff.lint]

spatialmath/base/animate.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import matplotlib.pyplot as plt
55
from matplotlib import animation
66
from matplotlib.axes import Axes
7+
from mpl_toolkits.mplot3d.axes3d import Axes3D
78
import spatialmath.base as smb
89
from collections.abc import Iterable, Iterator
910
from spatialmath.base.types import ArrayLike, SO3Array, SE3Array, SO2Array, SE2Array
@@ -40,7 +41,7 @@ class Animate:
4041

4142
def __init__(
4243
self,
43-
ax: Axes | None = None,
44+
ax: Axes3D | None = None,
4445
dim: ArrayLike | None = None,
4546
projection: str = "ortho",
4647
labels: tuple[str, str, str] = ("X", "Y", "Z"),
@@ -103,9 +104,9 @@ def __init__(
103104
raise ValueError(
104105
f"dim must have 2 or 6 elements, got {dim}. See docstring for details."
105106
)
106-
ax.set_xlim(dim[0:2])
107-
ax.set_ylim(dim[2:4])
108-
ax.set_zlim(dim[4:])
107+
ax.set_xlim((dim[0], dim[1]))
108+
ax.set_ylim((dim[2], dim[3]))
109+
ax.set_zlim((dim[4], dim[5]))
109110

110111
self.ax = ax
111112

@@ -528,6 +529,9 @@ def __init__(
528529
Will setup to plot into an existing or a new Axes3D instance.
529530
530531
"""
532+
assert not isinstance(dims, float)
533+
dims = np.array(dims)
534+
531535
self.trajectory = None
532536
self.displaylist = []
533537

@@ -547,8 +551,8 @@ def __init__(
547551
if dims is not None:
548552
if len(dims) == 2:
549553
dims = dims * 2
550-
axes.set_xlim(dims[0:2])
551-
axes.set_ylim(dims[2:4])
554+
axes.set_xlim((dims[0], dims[1]))
555+
axes.set_ylim((dims[2], dims[3]))
552556
# ax.set_aspect('equal')
553557

554558
self.ax = axes
@@ -601,7 +605,7 @@ def trplot2(
601605
def run(
602606
self,
603607
movie: str | bool | None = None,
604-
axes: plt.Axes | None = None,
608+
axes: Axes | None = None,
605609
repeat: bool = False,
606610
interval: int = 50,
607611
nframes: int = 100,
@@ -832,9 +836,6 @@ def __init__(self, anim, h, x, y):
832836

833837
def draw(self, T):
834838
p = T @ self.p
835-
# x2, y2, _ = proj3d.proj_transform(
836-
# p[0], p[1], p[2], self.anim.ax.get_proj())
837-
# self.h.set_position((x2, y2))
838839
self.h.set_position((p[0], p[1]))
839840

840841
def text(self, x, y, *args, **kwargs):
@@ -879,29 +880,6 @@ def set_ylabel(self, *args, **kwargs):
879880

880881

881882
if __name__ == "__main__":
882-
# from spatialmath import UnitQuaternion
883-
# from spatialmath.base import tranimate, r2t
884-
885-
# J = np.array([[2, -1, 0], [-1, 4, 0], [0, 0, 3]])
886-
# dt = 0.05
887-
# def attitude():
888-
# attitude = UnitQuaternion()
889-
# w = 0.2 * np.r_[1, 2, 2].T
890-
# for t in np.arange(0, 3, dt):
891-
# wd = -np.linalg.inv(J) @ (np.cross(w, J @ w))
892-
# w += wd * dt
893-
# attitude.increment(w * dt)
894-
# yield attitude.R
895-
# plt.figure()
896-
# plotvol3(2)
897-
# tranimate(attitude())
898-
899-
# T = smb.rpy2r(0.3, 0.4, 0.5)
900-
# # smb.tranimate(T, wait=True)
901-
# s = smb.tranimate(T, movie=True)
902-
# with open("zz.html", "w") as f:
903-
# print(f"<html>{s}</html>", file=f)
904-
905883
T = smb.rot2(2)
906884
# smb.tranimate2(T, wait=True)
907885
s = smb.tranimate2(T, movie=True)

spatialmath/base/argcheck.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from typing import Any, cast, overload, Callable
1414

1515

16-
1716
# valid scalar types
1817
_scalartypes = (int, np.integer, float, np.floating) + symtype
1918

@@ -544,7 +543,6 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> float | NDArray:
544543
:seealso: :func:`getvector`
545544
"""
546545
if not isinstance(v, Iterable) and dim == 0:
547-
548546
# scalar in, scalar out
549547
if unit == "rad":
550548
return v

spatialmath/base/graphics.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@
4343
from mpl_toolkits.mplot3d import Axes3D
4444

4545

46-
4746
# TODO
4847
# return a redrawer object, that can be used for animation
4948

5049
# =========================== 2D shapes =================================== #
5150

51+
5252
def plot_text(
5353
pos: ArrayLike2,
5454
text: str,
@@ -101,6 +101,7 @@ def plot_text(
101101
handle = ax.text(pos[0], pos[1], text, color=color, **kwargs)
102102
return [handle]
103103

104+
104105
def plot_point(
105106
pos: ArrayLike2,
106107
marker: str | None = "bs",
@@ -270,6 +271,7 @@ def plot_point(
270271
)
271272
return handles
272273

274+
273275
def plot_homline(
274276
lines: ArrayLike3 | NDArray,
275277
*args,
@@ -336,6 +338,7 @@ def plot_homline(
336338

337339
return handles
338340

341+
339342
def plot_box(
340343
*fmt: str | None,
341344
lbrt: ArrayLike4 | None = None,
@@ -509,6 +512,7 @@ def plot_box(
509512

510513
return r
511514

515+
512516
def plot_arrow(
513517
start: ArrayLike2,
514518
end: ArrayLike2,
@@ -609,6 +613,7 @@ def plot_arrow(
609613
label = " " + label
610614
ax.text(*pos, label, **opt)
611615

616+
612617
def plot_polygon(
613618
vertices: NDArray, *fmt, close: bool = False, **kwargs
614619
) -> list[Artist]:
@@ -643,6 +648,7 @@ def plot_polygon(
643648
vertices = np.hstack((vertices, vertices[:, [0]]))
644649
return _render2D(vertices, fmt=fmt, **kwargs)
645650

651+
646652
def _render2D(
647653
vertices: NDArray,
648654
pose=None,
@@ -668,6 +674,7 @@ def _render2D(
668674
r = plt.plot(vertices[0, :], vertices[1, :], *fmt, **kwargs)
669675
return r
670676

677+
671678
def circle(
672679
centre: ArrayLike2 = (0, 0),
673680
radius: float = 1,
@@ -707,6 +714,7 @@ def circle(
707714
else:
708715
return np.array((x, y))
709716

717+
710718
def plot_circle(
711719
radius: float,
712720
centre: ArrayLike2,
@@ -763,6 +771,7 @@ def plot_circle(
763771
handles.append(ax.plot(xy[0, :], xy[1, :], *fmt, **kwargs))
764772
return handles
765773

774+
766775
def ellipse(
767776
E: R2x2,
768777
centre: ArrayLike2 | None = (0, 0),
@@ -826,6 +835,7 @@ def ellipse(
826835
e = s * sqrtm(E) @ xy + np.array(centre, ndmin=2).T
827836
return e
828837

838+
829839
def plot_ellipse(
830840
E: R2x2,
831841
centre: ArrayLike2,
@@ -893,8 +903,10 @@ def plot_ellipse(
893903
else:
894904
plt.plot(xy[0, :], xy[1, :], *fmt, **kwargs)
895905

906+
896907
# =========================== 3D shapes =================================== #
897908

909+
898910
def sphere(
899911
radius: float | None = 1,
900912
centre: ArrayLike3 | None = (0, 0, 0),
@@ -927,6 +939,7 @@ def sphere(
927939

928940
return (x, y, z)
929941

942+
930943
def plot_sphere(
931944
radius: float,
932945
centre: ArrayLike3 | None = (0, 0, 0),
@@ -995,6 +1008,7 @@ def plot_sphere(
9951008

9961009
return handles
9971010

1011+
9981012
def ellipsoid(
9991013
E: R2x2,
10001014
centre: ArrayLike3 | None = (0, 0, 0),
@@ -1042,16 +1056,14 @@ def ellipsoid(
10421056

10431057
x, y, z = sphere() # unit sphere
10441058
centre = smb.getvector(centre, 3, out="col")
1045-
e = (
1046-
scale * sqrtm(E) @ np.array([x.flatten(), y.flatten(), z.flatten()])
1047-
+ centre
1048-
)
1059+
e = scale * sqrtm(E) @ np.array([x.flatten(), y.flatten(), z.flatten()]) + centre
10491060
return (
10501061
e[0, :].reshape(x.shape),
10511062
e[1, :].reshape(x.shape),
10521063
e[2, :].reshape(x.shape),
10531064
)
10541065

1066+
10551067
def plot_ellipsoid(
10561068
E: R3x3,
10571069
centre: ArrayLike3 | None = (0, 0, 0),
@@ -1111,6 +1123,7 @@ def plot_ellipsoid(
11111123
handle = _render3D(ax, X, Y, Z, **kwargs)
11121124
return [handle]
11131125

1126+
11141127
def cylinder(
11151128
center_x: float,
11161129
center_y: float,
@@ -1149,6 +1162,7 @@ def cylinder(
11491162
Y = radius * np.sin(theta_grid) + center_y
11501163
return X, Y, Z
11511164

1165+
11521166
# https://stackoverflow.com/questions/30715083/python-plotting-a-wireframe-3d-cuboid
11531167
# https://stackoverflow.com/questions/26874791/disconnected-surfaces-when-plotting-cones
11541168
def plot_cylinder(
@@ -1230,6 +1244,7 @@ def plot_cylinder(
12301244

12311245
return handles
12321246

1247+
12331248
def plot_cone(
12341249
radius: float,
12351250
height: float,
@@ -1298,9 +1313,7 @@ def plot_cone(
12981313

12991314
handles = []
13001315
handles.append(_render3D(ax, X, Y, Z, filled=filled, **kwargs))
1301-
handles.append(
1302-
_render3D(ax, X, (2 * centre[1] - Y), Z, filled=filled, **kwargs)
1303-
)
1316+
handles.append(_render3D(ax, X, (2 * centre[1] - Y), Z, filled=filled, **kwargs))
13041317

13051318
if ends and kwargs.get("filled", default=False):
13061319
floor = Circle(centre[:2], radius, **kwargs)
@@ -1313,6 +1326,7 @@ def plot_cone(
13131326

13141327
return handles
13151328

1329+
13161330
def plot_cuboid(
13171331
sides: ArrayLike3 = (1, 1, 1),
13181332
centre: ArrayLike3 | None = (0, 0, 0),
@@ -1408,6 +1422,7 @@ def plot_cuboid(
14081422
ax.add_collection3d(collection)
14091423
return collection
14101424

1425+
14111426
def _render3D(
14121427
ax: Axes,
14131428
X: NDArray,
@@ -1445,6 +1460,7 @@ def _render3D(
14451460
kwargs["colors"] = color
14461461
return ax.plot_wireframe(X, Y, Z, **kwargs)
14471462

1463+
14481464
def _axes_dimensions(ax: Axes) -> int:
14491465
"""
14501466
Dimensions of axes
@@ -1468,13 +1484,16 @@ def _axes_dimensions(ax: Axes) -> int:
14681484
# print("_axes_dimensions ", ax, ret)
14691485
return ret
14701486

1487+
14711488
def axes_get_limits(ax: Axes) -> NDArray:
14721489
return np.r_[ax.get_xlim(), ax.get_ylim()]
14731490

1491+
14741492
def axes_get_scale(ax: Axes) -> float:
14751493
limits = axes_get_limits(ax)
14761494
return max(abs(limits[1] - limits[0]), abs(limits[3] - limits[2]))
14771495

1496+
14781497
@overload
14791498
def axes_logic(
14801499
ax: Axes | None,
@@ -1483,6 +1502,7 @@ def axes_logic(
14831502
new: bool | None = False,
14841503
) -> Axes: ...
14851504

1505+
14861506
@overload
14871507
def axes_logic(
14881508
ax: Axes | None,
@@ -1492,6 +1512,7 @@ def axes_logic(
14921512
new: bool | None = False,
14931513
) -> Axes3D: ...
14941514

1515+
14951516
def axes_logic(
14961517
ax: Axes | Axes3D | None,
14971518
dimensions: int,
@@ -1566,6 +1587,7 @@ def axes_logic(
15661587

15671588
return ax
15681589

1590+
15691591
def plotvol2(
15701592
dim: ArrayLike | None = None,
15711593
ax: Axes | None = None,
@@ -1621,6 +1643,7 @@ def plotvol2(
16211643
ax._plotvol = True
16221644
return ax
16231645

1646+
16241647
def plotvol3(
16251648
dim: ArrayLike | None = None,
16261649
ax: Axes | None = None,
@@ -1682,6 +1705,7 @@ def plotvol3(
16821705
ax._plotvol = True
16831706
return ax
16841707

1708+
16851709
def expand_dims(dim: ArrayLike | None = None, nd: int = 2) -> NDArray:
16861710
"""
16871711
Expand compact axis dimensions
@@ -1729,6 +1753,7 @@ def expand_dims(dim: ArrayLike | None = None, nd: int = 2) -> NDArray:
17291753
else:
17301754
raise ValueError("nd is 2 or 3")
17311755

1756+
17321757
def isnotebook() -> bool:
17331758
"""
17341759
Determine if code is being run from a Jupyter notebook
@@ -1751,6 +1776,7 @@ def isnotebook() -> bool:
17511776
except NameError:
17521777
return False # Probably standard Python interpreter
17531778

1779+
17541780
if __name__ == "__main__":
17551781
import pathlib
17561782

0 commit comments

Comments
 (0)