Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions src/main/java/com/thealgorithms/matrix/QRDecomposition.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package com.thealgorithms.matrix;

/**
* @brief Implementation of QR Decomposition using the Gram-Schmidt process
* @details Decomposes a matrix A into an orthogonal matrix Q and an upper
* triangular matrix R such that A = Q * R. The Gram-Schmidt process
* orthogonalizes the columns of A to produce Q, and R is computed as Q^T * A.
* This decomposition is useful for solving linear least squares problems,
* eigenvalue computations, and numerical stability in linear algebra.
* @see <a href="https://en.wikipedia.org/wiki/QR_decomposition">QR Decomposition</a>
*/
public final class QRDecomposition {

private QRDecomposition() {
}

/**
* A helper class to store both Q and R matrices
*/
public static class QR {
private final double[][] q;
private final double[][] r;

QR(double[][] q, double[][] r) {
this.q = q;
this.r = r;
}

public double[][] getQ() {
return q;
}

public double[][] getR() {
return r;
}
}

/**
* @brief Performs QR decomposition on a matrix using the Gram-Schmidt process
* @param matrix the input matrix (m x n)
* @return QR object containing orthogonal matrix Q (m x n) and upper triangular matrix R (n x n)
* @throws IllegalArgumentException if the matrix is null, empty, or has invalid rows
*/
public static QR decompose(double[][] matrix) {
validateInputMatrix(matrix);

int m = matrix.length;
int n = matrix[0].length;

double[][] q = new double[m][n];
double[][] r = new double[n][n];

for (int j = 0; j < n; j++) {
double[] v = getColumn(matrix, j);

for (int i = 0; i < j; i++) {
double[] qi = getColumn(q, i);
r[i][j] = dotProduct(qi, v);
v = subtractVectors(v, scalarMultiply(qi, r[i][j]));
}

r[j][j] = norm(v);
if (r[j][j] == 0) {
throw new ArithmeticException("Matrix is rank deficient. Cannot perform QR decomposition.");
}
double[] qj = scalarMultiply(v, 1.0 / r[j][j]);
setColumn(q, j, qj);
}

return new QR(q, r);
}

private static double[] getColumn(double[][] matrix, int col) {
int m = matrix.length;
double[] column = new double[m];
for (int i = 0; i < m; i++) {
column[i] = matrix[i][col];
}
return column;
}

private static void setColumn(double[][] matrix, int col, double[] column) {
for (int i = 0; i < matrix.length; i++) {
matrix[i][col] = column[i];
}
}

private static double dotProduct(double[] a, double[] b) {
double sum = 0;
for (int i = 0; i < a.length; i++) {
sum += a[i] * b[i];
}
return sum;
}

private static double[] subtractVectors(double[] a, double[] b) {
double[] result = new double[a.length];
for (int i = 0; i < a.length; i++) {
result[i] = a[i] - b[i];
}
return result;
}

private static double[] scalarMultiply(double[] v, double scalar) {
double[] result = new double[v.length];
for (int i = 0; i < v.length; i++) {
result[i] = v[i] * scalar;
}
return result;
}

private static double norm(double[] v) {
return Math.sqrt(dotProduct(v, v));
}

private static void validateInputMatrix(double[][] matrix) {
if (matrix == null) {
throw new IllegalArgumentException("The input matrix cannot be null");
}
if (matrix.length == 0) {
throw new IllegalArgumentException("The input matrix cannot be empty");
}
if (!hasValidRows(matrix)) {
throw new IllegalArgumentException("The input matrix cannot have null or empty rows");
}
if (isJaggedMatrix(matrix)) {
throw new IllegalArgumentException("The input matrix cannot be jagged");
}
}

private static boolean hasValidRows(double[][] matrix) {
for (double[] row : matrix) {
if (row == null || row.length == 0) {
return false;
}
}
return true;
}

private static boolean isJaggedMatrix(double[][] matrix) {
int numColumns = matrix[0].length;
for (double[] row : matrix) {
if (row.length != numColumns) {
return true;
}
}
return false;
}
}
115 changes: 115 additions & 0 deletions src/test/java/com/thealgorithms/matrix/QRDecompositionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package com.thealgorithms.matrix;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;

public class QRDecompositionTest {

private static final double DELTA = 1e-9;

@Test
public void testQRDecomposition2x2() {
double[][] matrix = {{12, -51}, {6, 167}};
QRDecomposition.QR qr = QRDecomposition.decompose(matrix);
double[][] q = qr.getQ();
double[][] r = qr.getR();

double[][] reconstructed = multiplyMatrices(q, r);
for (int i = 0; i < matrix.length; i++) {
assertArrayEquals(matrix[i], reconstructed[i], DELTA);
}
}

@Test
public void testQRDecomposition3x3() {
double[][] matrix = {{1, 1, 0}, {1, 0, 1}, {0, 1, 1}};
QRDecomposition.QR qr = QRDecomposition.decompose(matrix);
double[][] q = qr.getQ();
double[][] r = qr.getR();

double[][] reconstructed = multiplyMatrices(q, r);
for (int i = 0; i < matrix.length; i++) {
assertArrayEquals(matrix[i], reconstructed[i], DELTA);
}
}

@Test
public void testQROrthogonalColumns() {
double[][] matrix = {{1, 1, 0}, {1, 0, 1}, {0, 1, 1}};
QRDecomposition.QR qr = QRDecomposition.decompose(matrix);
double[][] q = qr.getQ();

for (int i = 0; i < q[0].length; i++) {
for (int j = i; j < q[0].length; j++) {
double dot = 0;
for (int k = 0; k < q.length; k++) {
dot += q[k][i] * q[k][j];
}
if (i == j) {
assertArrayEquals(new double[] {1.0}, new double[] {dot}, DELTA);
} else {
assertArrayEquals(new double[] {0.0}, new double[] {dot}, DELTA);
}
}
}
}

@Test
public void testRIsUpperTriangular() {
double[][] matrix = {{12, -51}, {6, 167}};
QRDecomposition.QR qr = QRDecomposition.decompose(matrix);
double[][] r = qr.getR();

for (int i = 1; i < r.length; i++) {
for (int j = 0; j < i; j++) {
assertArrayEquals(new double[] {0.0}, new double[] {r[i][j]}, DELTA);
}
}
}

@Test
public void testQRDecompositionIdentityMatrix() {
double[][] matrix = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};
QRDecomposition.QR qr = QRDecomposition.decompose(matrix);
double[][] q = qr.getQ();
double[][] r = qr.getR();

for (int i = 0; i < matrix.length; i++) {
assertArrayEquals(matrix[i], q[i], DELTA);
assertArrayEquals(matrix[i], r[i], DELTA);
}
}

@Test
public void testQRDecompositionRankDeficientThrows() {
double[][] matrix = {{1, 2}, {2, 4}};
assertThrows(ArithmeticException.class, () -> QRDecomposition.decompose(matrix));
}

@Test
public void testQRDecompositionNullMatrixThrows() {
assertThrows(IllegalArgumentException.class, () -> QRDecomposition.decompose(null));
}

@Test
public void testQRDecompositionEmptyMatrixThrows() {
assertThrows(IllegalArgumentException.class, () -> QRDecomposition.decompose(new double[0][0]));
}

private static double[][] multiplyMatrices(double[][] a, double[][] b) {
int m = a.length;
int n = b[0].length;
int k = a[0].length;
double[][] result = new double[m][n];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
for (int p = 0; p < k; p++) {
result[i][j] += a[i][p] * b[p][j];
}
}
}
return result;
}
}
Loading