Skip to content

Commit 34d8583

Browse files
committed
clean up code, taking advantage of the more general wrap methods
1 parent fd9d2ab commit 34d8583

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

src/fastLm.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -174,36 +174,35 @@ static inline lm do_lm(const MMatrixXd &X, const MVectorXd &y, int type)
174174

175175
extern "C" SEXP fastLm(SEXP Xs, SEXP ys, SEXP type) {
176176
try {
177-
const NumericMatrix X(Xs);
178-
const NumericVector y(ys);
179-
Index n = X.nrow(), p = X.ncol();
177+
const MMatrixXd X(as<MMatrixXd>(Xs));
178+
const MVectorXd y(as<MVectorXd>(ys));
179+
Index n = X.rows(), p = X.cols();
180180
if ((Index)y.size() != n)
181181
throw std::invalid_argument("size mismatch");
182-
const MVectorXd yy(y.begin(), n);
183-
const MMatrixXd XX(X.begin(), n, p);
184182

185-
lm ans = do_lm(XX, yy, ::Rf_asInteger(type));
186-
NumericVector coef = wrap(ans.coef());
187-
// install the names, if available
188-
List dimnames = X.attr("dimnames");
183+
lm ans = do_lm(X, y, ::Rf_asInteger(type));
184+
// Copy coefficients and install names, if available
185+
NumericVector coef = wrap(ans.coef());
186+
List dimnames = NumericMatrix(Xs).attr("dimnames");
189187
if (dimnames.size() > 1) {
190-
RObject colnames = dimnames[1];
188+
RObject colnames = dimnames[1];
191189
if (!(colnames).isNULL())
192190
coef.attr("names") = clone(CharacterVector(colnames));
193191
}
194192

195-
VectorXd resid = yy - ans.fitted();
196-
double s2 = resid.squaredNorm()/ans.df();
197-
PermutationType Pmat = PermutationType(p);
198-
Pmat.indices() = ans.perm();
199-
VectorXd dd = Pmat * ans.unsc().diagonal();
200-
ArrayXd se = (dd.array() * s2).sqrt();
193+
VectorXd resid = y - ans.fitted();
194+
double s2 = resid.squaredNorm()/ans.df();
195+
// Create the standard errors
196+
PermutationType Pmat = PermutationType(p);
197+
Pmat.indices() = ans.perm();
198+
VectorXd dd = Pmat * ans.unsc().diagonal();
199+
ArrayXd se = (dd.array() * s2).sqrt();
201200

202201
return List::create(_["coefficients"] = coef,
203202
_["se"] = se,
204203
_["rank"] = ans.rank(),
205204
_["df.residual"] = ans.df(),
206-
_["perm"] = IntegerVector(ans.perm().data(), ans.perm().data() + p),
205+
_["perm"] = ans.perm(),
207206
_["residuals"] = resid,
208207
_["s2"] = s2,
209208
_["fitted.values"] = ans.fitted(),

0 commit comments

Comments
 (0)