@@ -99,7 +99,7 @@ def verifymatrix(m, shape):
9999# and not np.iscomplex(m) checks every element, would need to be not np.any(np.iscomplex(m)) which seems expensive
100100
101101
102- def getvector (v , dim = None , out = 'array' ):
102+ def getvector (v , dim = None , out = 'array' , dtype = np . float64 ):
103103 """
104104 Return a vector value
105105
@@ -108,6 +108,8 @@ def getvector(v, dim=None, out='array'):
108108 :type dim: int or None
109109 :param out: output format, default is 'array'
110110 :type out: str
111+ :param dtype: datatype for NumPy array return (default np.float64)
112+ :type dtype: NumPy type
111113 :return: vector value in specified format
112114
113115 The passed vector can be any of:
@@ -126,13 +128,17 @@ def getvector(v, dim=None, out='array'):
126128 'row' row vector, a 2D NumPy array, shape=(1,N)
127129 'col' column vector, 2D NumPy array, shape=(N,1)
128130 ========== ===============================================
131+
132+ For 'array', 'row' or 'col' the NumPy dtype defaults to ``np.float64`` but
133+ can be overriden using the ``dtype`` argument.
129134 """
130135 if isinstance (v , (int , np .int64 , float )) or (
131136 _sympy and isinstance (v , sympy .Expr )): # handle scalar case
132137 v = [v ]
133138
134139 if isinstance (v , (list , tuple )):
135- dt = np .float64 # return np arrays of this type
140+ # list or tuple was passed in
141+ dt = dtype
136142 if _sympy :
137143 if any ([isinstance (x , sympy .Expr ) for x in v ]):
138144 dt = None
@@ -149,6 +155,7 @@ def getvector(v, dim=None, out='array'):
149155 return np .array (v , dtype = dt ).reshape (- 1 , 1 )
150156 else :
151157 raise ValueError ("invalid output specifier" )
158+
152159 elif isinstance (v , np .ndarray ):
153160 s = v .shape
154161 if dim is not None :
@@ -158,7 +165,7 @@ def getvector(v, dim=None, out='array'):
158165 v = v .flatten ()
159166
160167 if v .dtype .kind != 'O' :
161- dt = np . float64
168+ dt = dtype
162169
163170 if out == 'sequence' :
164171 return list (v .flatten ())
0 commit comments