GLSMultipleLinearRegression.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.stat.regression;

  18. import org.apache.commons.math3.linear.LUDecomposition;
  19. import org.apache.commons.math3.linear.RealMatrix;
  20. import org.apache.commons.math3.linear.Array2DRowRealMatrix;
  21. import org.apache.commons.math3.linear.RealVector;

  22. /**
  23.  * The GLS implementation of multiple linear regression.
  24.  *
  25.  * GLS assumes a general covariance matrix Omega of the error
  26.  * <pre>
  27.  * u ~ N(0, Omega)
  28.  * </pre>
  29.  *
  30.  * Estimated by GLS,
  31.  * <pre>
  32.  * b=(X' Omega^-1 X)^-1X'Omega^-1 y
  33.  * </pre>
  34.  * whose variance is
  35.  * <pre>
  36.  * Var(b)=(X' Omega^-1 X)^-1
  37.  * </pre>
  38.  * @since 2.0
  39.  */
  40. public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

  41.     /** Covariance matrix. */
  42.     private RealMatrix Omega;

  43.     /** Inverse of covariance matrix. */
  44.     private RealMatrix OmegaInverse;

  45.     /** Replace sample data, overriding any previous sample.
  46.      * @param y y values of the sample
  47.      * @param x x values of the sample
  48.      * @param covariance array representing the covariance matrix
  49.      */
  50.     public void newSampleData(double[] y, double[][] x, double[][] covariance) {
  51.         validateSampleData(x, y);
  52.         newYSampleData(y);
  53.         newXSampleData(x);
  54.         validateCovarianceData(x, covariance);
  55.         newCovarianceData(covariance);
  56.     }

  57.     /**
  58.      * Add the covariance data.
  59.      *
  60.      * @param omega the [n,n] array representing the covariance
  61.      */
  62.     protected void newCovarianceData(double[][] omega){
  63.         this.Omega = new Array2DRowRealMatrix(omega);
  64.         this.OmegaInverse = null;
  65.     }

  66.     /**
  67.      * Get the inverse of the covariance.
  68.      * <p>The inverse of the covariance matrix is lazily evaluated and cached.</p>
  69.      * @return inverse of the covariance
  70.      */
  71.     protected RealMatrix getOmegaInverse() {
  72.         if (OmegaInverse == null) {
  73.             OmegaInverse = new LUDecomposition(Omega).getSolver().getInverse();
  74.         }
  75.         return OmegaInverse;
  76.     }

  77.     /**
  78.      * Calculates beta by GLS.
  79.      * <pre>
  80.      *  b=(X' Omega^-1 X)^-1X'Omega^-1 y
  81.      * </pre>
  82.      * @return beta
  83.      */
  84.     @Override
  85.     protected RealVector calculateBeta() {
  86.         RealMatrix OI = getOmegaInverse();
  87.         RealMatrix XT = getX().transpose();
  88.         RealMatrix XTOIX = XT.multiply(OI).multiply(getX());
  89.         RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse();
  90.         return inverse.multiply(XT).multiply(OI).operate(getY());
  91.     }

  92.     /**
  93.      * Calculates the variance on the beta.
  94.      * <pre>
  95.      *  Var(b)=(X' Omega^-1 X)^-1
  96.      * </pre>
  97.      * @return The beta variance matrix
  98.      */
  99.     @Override
  100.     protected RealMatrix calculateBetaVariance() {
  101.         RealMatrix OI = getOmegaInverse();
  102.         RealMatrix XTOIX = getX().transpose().multiply(OI).multiply(getX());
  103.         return new LUDecomposition(XTOIX).getSolver().getInverse();
  104.     }


  105.     /**
  106.      * Calculates the estimated variance of the error term using the formula
  107.      * <pre>
  108.      *  Var(u) = Tr(u' Omega^-1 u)/(n-k)
  109.      * </pre>
  110.      * where n and k are the row and column dimensions of the design
  111.      * matrix X.
  112.      *
  113.      * @return error variance
  114.      * @since 2.2
  115.      */
  116.     @Override
  117.     protected double calculateErrorVariance() {
  118.         RealVector residuals = calculateResiduals();
  119.         double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
  120.         return t / (getX().getRowDimension() - getX().getColumnDimension());

  121.     }

  122. }