Skip to content

Commit b93fda7

Browse files
committed
allow return dtype to be specified
1 parent cfbe764 commit b93fda7

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

spatialmath/base/argcheck.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def verifymatrix(m, shape):
9999
# and not np.iscomplex(m) checks every element, would need to be not np.any(np.iscomplex(m)) which seems expensive
100100

101101

102-
def getvector(v, dim=None, out='array'):
102+
def getvector(v, dim=None, out='array', dtype=np.float64):
103103
"""
104104
Return a vector value
105105
@@ -108,6 +108,8 @@ def getvector(v, dim=None, out='array'):
108108
:type dim: int or None
109109
:param out: output format, default is 'array'
110110
:type out: str
111+
:param dtype: datatype for NumPy array return (default np.float64)
112+
:type dtype: NumPy type
111113
:return: vector value in specified format
112114
113115
The passed vector can be any of:
@@ -126,13 +128,17 @@ def getvector(v, dim=None, out='array'):
126128
'row' row vector, a 2D NumPy array, shape=(1,N)
127129
'col' column vector, 2D NumPy array, shape=(N,1)
128130
========== ===============================================
131+
132+
For 'array', 'row' or 'col' the NumPy dtype defaults to ``np.float64`` but
133+
can be overriden using the ``dtype`` argument.
129134
"""
130135
if isinstance(v, (int, np.int64, float)) or (
131136
_sympy and isinstance(v, sympy.Expr)): # handle scalar case
132137
v = [v]
133138

134139
if isinstance(v, (list, tuple)):
135-
dt = np.float64 # return np arrays of this type
140+
# list or tuple was passed in
141+
dt = dtype
136142
if _sympy:
137143
if any([isinstance(x, sympy.Expr) for x in v]):
138144
dt = None
@@ -149,6 +155,7 @@ def getvector(v, dim=None, out='array'):
149155
return np.array(v, dtype=dt).reshape(-1, 1)
150156
else:
151157
raise ValueError("invalid output specifier")
158+
152159
elif isinstance(v, np.ndarray):
153160
s = v.shape
154161
if dim is not None:
@@ -158,7 +165,7 @@ def getvector(v, dim=None, out='array'):
158165
v = v.flatten()
159166

160167
if v.dtype.kind != 'O':
161-
dt = np.float64
168+
dt = dtype
162169

163170
if out == 'sequence':
164171
return list(v.flatten())

0 commit comments

Comments
 (0)