001/*
002 * JScience - Java(TM) Tools and Libraries for the Advancement of Sciences.
003 * Copyright (C) 2006 - JScience (http://jscience.org/)
004 * All rights reserved.
005 * 
006 * Permission to use, copy, modify, and distribute this software is
007 * freely granted, provided that this notice is preserved.
008 */
009package org.jscience.mathematics.vector;
010
011import java.util.Comparator;
012
013import org.jscience.mathematics.structure.Field;
014import org.jscience.mathematics.number.Number;
015
016import javolution.context.LocalContext;
017import javolution.context.ObjectFactory;
018import javolution.util.FastTable;
019import javolution.util.Index;
020
021/**
022 * <p> This class represents the decomposition of a {@link Matrix matrix} 
023 *     <code>A</code> into a product of a {@link #getLower lower} 
024 *     and {@link #getUpper upper} triangular matrices, <code>L</code>
025 *     and <code>U</code> respectively, such as <code>A = P·L·U<code> with 
026 *     <code>P<code> a {@link #getPermutation permutation} matrix.</p>
027 *     
028 * <p> This decomposition</a> is typically used to resolve linear systems
029 *     of equations (Gaussian elimination) or to calculate the determinant
030 *     of a square {@link Matrix} (<code>O(m³)</code>).</p>
031 *     
032 * <p> Numerical stability is guaranteed through pivoting if the
033 *     {@link Field} elements are {@link Number numbers}
034 *     For others elements types, numerical stability can be ensured by setting
035 *     the {@link javolution.context.LocalContext context-local} pivot 
036 *     comparator (see {@link #setPivotComparator}).</p>
037 *     
038 * <p> Pivoting can be disabled by setting the {@link #setPivotComparator 
039 *     pivot comparator} to <code>null</code> ({@link #getPermutation P} 
040 *     is then the matrix identity).</p>
041 *     
042 * @author <a href="mailto:jean-marie@dautelle.com">Jean-Marie Dautelle</a>
043 * @version 3.3, January 2, 2007
044 * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">
045 *      Wikipedia: LU decomposition</a>
046 */
047public final class LUDecomposition<F extends Field<F>>  {
048
049    /**
050     * Holds the default comparator for pivoting.
051     */
052    public static final Comparator<Field> NUMERIC_COMPARATOR = new Comparator<Field>() {
053
054        @SuppressWarnings("unchecked")
055        public int compare(Field left, Field right) {
056            if ((left instanceof Number) && (right instanceof Number))
057                return ((Number) left).isLargerThan((Number) right) ? 1 : -1;
058            if (left.equals(left.plus(left))) // Zero
059                return -1;
060            if (right.equals(right.plus(right))) // Zero
061                return 1;
062            return 0;
063        }
064    };
065
066    /**
067     * Holds the local comparator.
068     */
069    private static final LocalContext.Reference<Comparator<Field>> PIVOT_COMPARATOR = new LocalContext.Reference<Comparator<Field>>(
070            NUMERIC_COMPARATOR);
071
072    /**
073     * Holds the dimension of the square matrix source.
074     */
075    private int _n;
076
077    /**
078     * Holds the pivots indexes.
079     */
080    private final FastTable<Index> _pivots = new FastTable<Index>();
081
082    /**
083     * Holds the LU elements.
084     */
085    private DenseMatrix<F> _LU;
086
087    /**
088     * Holds the number of permutation performed.
089     */
090    private int _permutationCount;
091
092    /**
093     * Returns the lower/upper decomposition of the specified matrix.
094     *
095     * @param  source the matrix for which the decomposition is calculated.
096     * @return the lower/upper decomposition of the specified matrix.
097     * @throws DimensionException if the specified matrix is not square.
098     */
099    @SuppressWarnings("unchecked")
100    public static <F extends Field<F>> LUDecomposition<F> valueOf(
101            Matrix<F> source) {
102        if (!source.isSquare())
103            throw new DimensionException("Matrix is not square");
104        int dimension = source.getNumberOfRows();
105        LUDecomposition lu = FACTORY.object();
106        lu._n = dimension;
107        lu._permutationCount = 0;
108        lu.construct(source);
109        return lu;
110    }
111
112    /**
113     * Constructs the LU decomposition of the specified matrix.
114     * We make the choise of Lii = ONE (diagonal elements of the
115     * lower triangular matrix are multiplicative identities).
116     *
117     * @param  source the matrix to decompose.
118     * @throws MatrixException if the matrix source is not square.
119     */
120    private void construct(Matrix<F> source) {
121        _LU = source instanceof DenseMatrix ? ((DenseMatrix<F>) source).copy()
122                : DenseMatrix.valueOf(source);
123        _pivots.clear();
124        for (int i = 0; i < _n; i++) {
125            _pivots.add(Index.valueOf(i));
126        }
127
128        // Main loop.
129        Comparator<Field> cmp = LUDecomposition.getPivotComparator();
130        final int n = _n;
131        for (int k = 0; k < _n; k++) {
132
133            if (cmp != null) { // Pivoting enabled.
134                // Rearranges the rows so that the absolutely largest
135                // elements of the matrix source in each column lies
136                // in the diagonal.
137                int pivot = k;
138                for (int i = k + 1; i < n; i++) {
139                    if (cmp.compare(_LU.get(i, k), _LU.get(pivot, k)) > 0) {
140                        pivot = i;
141                    }
142                }
143                if (pivot != k) { // Exchanges.
144                    for (int j = 0; j < n; j++) {
145                        F tmp = _LU.get(pivot, j);
146                        _LU.set(pivot, j, _LU.get(k, j));
147                        _LU.set(k, j, tmp);
148                    }
149                    int j = _pivots.get(pivot).intValue();
150                    _pivots.set(pivot, _pivots.get(k));
151                    _pivots.set(k, Index.valueOf(j));
152                    _permutationCount++;
153                }
154            }
155
156            // Computes multipliers and eliminate k-th column.
157            F lukkInv = _LU.get(k, k).inverse();
158            for (int i = k + 1; i < n; i++) {
159                // Multiplicative order is important
160                // for non-commutative elements.
161                _LU.set(i, k, _LU.get(i, k).times(lukkInv));
162                for (int j = k + 1; j < n; j++) {
163                    _LU.set(i, j, _LU.get(i, j).plus(
164                            _LU.get(i, k).times(_LU.get(k, j).opposite())));
165                }
166            }
167        }
168    }
169
170    /**
171     * Sets the {@link javolution.context.LocalContext local} comparator used 
172     * for pivoting or <code>null</code> to disable pivoting.
173     *
174     * @param  cmp the comparator for pivoting or <code>null</code>.
175     */
176    public static void setPivotComparator(Comparator<Field> cmp) {
177        PIVOT_COMPARATOR.set(cmp);
178    }
179
180    /**
181     * Returns the {@link javolution.context.LocalContext local} 
182     * comparator used for pivoting or <code>null</code> if pivoting 
183     * is not performed (default {@link #NUMERIC_COMPARATOR}).
184     *
185     * @return the comparator for pivoting or <code>null</code>.
186     */
187    public static Comparator<Field> getPivotComparator() {
188        return PIVOT_COMPARATOR.get();
189    }
190
191    /**
192     * Returns the solution X of the equation: A * X = B  with
193     * <code>this = A.lu()</code> using back and forward substitutions.
194     *
195     * @param  B the input matrix.
196     * @return the solution X = (1 / A) * B.
197     * @throws DimensionException if the dimensions do not match.
198     */
199    public DenseMatrix<F> solve(Matrix<F> B) {
200        if (_n != B.getNumberOfRows())
201            throw new DimensionException("Input vector has "
202                    + B.getNumberOfRows() + " rows instead of " + _n);
203
204        // Copies B with pivoting.
205        final int n = B.getNumberOfColumns();
206        DenseMatrix<F> X = createNullDenseMatrix(_n, n);
207        for (int i = 0; i < _n; i++) {
208            for (int j = 0; j < n; j++) {
209                X.set(i, j, B.get(_pivots.get(i).intValue(), j));
210            }
211        }
212
213        // Solves L * Y = pivot(B)
214        for (int k = 0; k < _n; k++) {
215            for (int i = k + 1; i < _n; i++) {
216                F luik = _LU.get(i, k);
217                for (int j = 0; j < n; j++) {
218                    X.set(i, j, X.get(i, j).plus(
219                            luik.times(X.get(k, j).opposite())));
220                }
221            }
222        }
223
224        // Solves U * X = Y;
225        for (int k = _n - 1; k >= 0; k--) {
226            for (int j = 0; j < n; j++) {
227                X.set(k, j, (_LU.get(k, k).inverse()).times(X.get(k, j)));
228            }
229            for (int i = 0; i < k; i++) {
230                F luik = _LU.get(i, k);
231                for (int j = 0; j < n; j++) {
232                    X.set(i, j, X.get(i, j).plus(
233                            luik.times(X.get(k, j).opposite())));
234                }
235            }
236        }
237        return X;
238    }
239
240    private DenseMatrix<F> createNullDenseMatrix(int m, int n) {
241        DenseMatrix<F> M = DenseMatrix.newInstance(n, false);
242        for (int i = 0; i < m; i++) {
243            DenseVector<F> V = DenseVector.newInstance();
244            M._rows.add(V);
245            for (int j = 0; j < n; j++) {
246                V._elements.add(null);
247            }
248        }
249        return M;
250    }
251
252    /**
253     * Returns the solution X of the equation: A * X = Identity  with
254     * <code>this = A.lu()</code> using back and forward substitutions.
255     *
256     * @return <code>this.solve(Identity)</code>
257     */
258    public DenseMatrix<F> inverse() {
259        // Calculates inv(U).
260        final int n = _n;
261        DenseMatrix<F> R = createNullDenseMatrix(n, n);
262        for (int i = 0; i < n; i++) {
263            for (int j = i; j < n; j++) {
264                R.set(i, j, _LU.get(i, j));
265            }
266        }
267        for (int j = n - 1; j >= 0; j--) {
268            R.set(j, j, R.get(j, j).inverse());
269            for (int i = j - 1; i >= 0; i--) {
270                F sum = R.get(i, j).times(R.get(j, j).opposite());
271                for (int k = j - 1; k > i; k--) {
272                    sum = sum.plus(R.get(i, k).times(R.get(k, j).opposite()));
273                }
274                R.set(i, j, (R.get(i, i).inverse()).times(sum));
275            }
276        }
277        // Solves inv(A) * L = inv(U)
278        for (int i = 0; i < n; i++) {
279            for (int j = n - 2; j >= 0; j--) {
280                for (int k = j + 1; k < n; k++) {
281                    F lukj = _LU.get(k, j);
282                    if (R.get(i, j) != null) {
283                        R.set(i, j, R.get(i, j).plus(
284                                R.get(i, k).times(lukj.opposite())));
285                    } else {
286                        R.set(i, j, R.get(i, k).times(lukj.opposite()));
287                    }
288                }
289            }
290        }
291        // Swaps columns (reverses pivots permutations).
292        FastTable<F> tmp = FastTable.newInstance();
293        for (int i = 0; i < n; i++) {
294            tmp.reset();
295            for (int j = 0; j < n; j++) {
296                tmp.add(R.get(i, j));
297            }
298            for (int j = 0; j < n; j++) {
299                R.set(i, _pivots.get(j).intValue(), tmp.get(j));
300            }
301        }
302        FastTable.recycle(tmp);
303        return R;
304    }
305
306    /**
307     * Returns the determinant of the {@link Matrix} having this
308     * decomposition.
309     *
310     * @return the determinant of the matrix source.
311     */
312    public F determinant() {
313        F product = _LU.get(0, 0);
314        for (int i = 1; i < _n; i++) {
315            product = product.times(_LU.get(i, i));
316        }
317        return ((_permutationCount & 1) == 0) ? product : product.opposite();
318    }
319
320    /**
321     * Returns the lower matrix decomposition (<code>L</code>) with diagonal
322     * elements equal to the multiplicative identity for F. 
323     *
324     * @param zero the additive identity for F.
325     * @param one the multiplicative identity for F.
326     * @return the lower matrix.
327     */
328    public DenseMatrix<F> getLower(F zero, F one) {
329        DenseMatrix<F> L = _LU.copy();
330        for (int j = 0; j < _n; j++) {
331            for (int i = 0; i < j; i++) {
332                L.set(i, j, zero);
333            }
334            L.set(j, j, one);
335        }
336        return L;
337    }
338
339    /**
340     * Returns the upper matrix decomposition (<code>U</code>). 
341     *
342     * @param zero the additive identity for F.
343     * @return the upper matrix.
344     */
345    public DenseMatrix<F> getUpper(F zero) {
346        DenseMatrix<F> U = _LU.copy();
347        for (int j = 0; j < _n; j++) {
348            for (int i = j + 1; i < _n; i++) {
349                U.set(i, j, zero);
350            }
351        }
352        return U;
353    }
354
355    /**
356     * Returns the permutation matrix (<code>P</code>). 
357     *
358     * @param zero the additive identity for F.
359     * @param one the multiplicative identity for F.
360     * @return the permutation matrix.
361     */
362    public SparseMatrix<F> getPermutation(F zero, F one) {
363        SparseMatrix<F> P = SparseMatrix.newInstance(_n, zero, false);
364        for (int i = 0; i < _n; i++) {
365            P.getRow(_pivots.get(i).intValue())._elements.put(Index.valueOf(i),
366                    one);
367        }
368        return P;
369    }
370
371    /**
372     * Returns the lower/upper decomposition in one single matrix. 
373     *
374     * @return the lower/upper matrix merged in a single matrix.
375     */
376    public DenseMatrix<F> getLU() {
377        return _LU;
378    }
379
380    /**
381     * Returns the pivots elements of this decomposition. 
382     *
383     * @return the row indices after permutation.
384     */
385    public FastTable<Index> getPivots() {
386        return _pivots;
387    }
388
389    
390    ///////////////////////
391    // Factory creation. //
392    ///////////////////////
393
394    private static final ObjectFactory<LUDecomposition> FACTORY = new ObjectFactory<LUDecomposition>() {
395        protected LUDecomposition create() {
396            return new LUDecomposition();
397        }
398
399        @SuppressWarnings("unchecked")
400        protected void cleanup(LUDecomposition lu) {
401            lu._LU = null;
402        }
403    };
404
405    private LUDecomposition() {
406    }
407
408}