org.knowceans.sandbox.gauss
Class IgmmGibbsSampler

java.lang.Object
  extended by org.knowceans.sandbox.gauss.IgmmGibbsSampler

public class IgmmGibbsSampler
extends java.lang.Object

IgmmGibbsSampler implements the infinite Gaussian mixture model as a Gibbs sampler. The algorithm is fully non-parametric, i.e., the entire set of parameters is estimated using empirical Bayes, and the only prior knowledge involved is the choice of distributions.

Rasmussen (NIPS-12) presented the underlying approach whose complete probability model is as follows:

 Mixture model:
 p(x)    = sum_j^k pi_j * N(x | mu_j, 1/s_j)
 Data points:
 x | c_j ~ N(mu_j, 1/s_j)
 Mean hyperparameters: 
 mu_j    ~ N(lambda, 1/r)
 lambda  ~ N(mu_y, sigma_y^2)
 r       ~ Gamma(1, 1/sigma_y^2)
         = 1/Z * r^-.5 * exp(-r * sigma_y^2/2)
 Precision hyperparameters:
 s_j     ~ Gamma(beta, 1/w)
 w       ~ Gamma(1, sigma_y^2)
 Mixture weight hyperparameters:
 pi_j    ~ Dirichlet(alpha/k)
 alpha   ~ Gamma(a,b)
 
Note that Rasmussen defines Gamma(a,b) as having mean b, not a*b, which is reflected in the method sampleGammaDist().
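
A minimal usage sketch under stated assumptions: the example data and sampling settings are invented, the result accessors are package-private (so the call site must live in org.knowceans.sandbox.gauss), and running the chain itself is left to the driver, since gibbs() is private and invoked via main().

 double[] data = {1.1, 0.9, 1.3, 5.2, 4.8, 5.0, 9.7, 10.1};  // invented data
 IgmmGibbsSampler igmm = new IgmmGibbsSampler(data);
 igmm.configure(10000, 2000, 10);           // iterations, burn-in, thinning lag
 // ... the driver runs the chain here ...
 double[] means   = igmm.getMean();         // component means
 double[] stddevs = igmm.getStdDev(means);  // component standard deviations
 double[] weights = igmm.getWeights();      // mixture weights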

Author:
heinrich

Nested Class Summary
(package private)  class IgmmGibbsSampler.AlphaArms
          AlphaArms implements the ARMS update for the CRP hyperparameter
(package private)  class IgmmGibbsSampler.BetaArms
          BetaArms implements the ARMS update for the precision's precision hyperparameter
 
Field Summary
private  double alpha
          CRP concentration parameter
private  IgmmGibbsSampler.AlphaArms alphaSampler
          sampler for alpha
private  double beta
          mean of s_j
private  IgmmGibbsSampler.BetaArms betaSampler
          sampler for beta
private  int burnIn
          burn-in period
private  int[] cc
          state (component) for each data point
 int debugLevel
          debug level for output (5=info, 1=error)
private  int growstep
          array grow step
private  int iterations
          max iterations
private  int k
          number of components
private  double lambda
          mean of mu_j
private  double[] mu
          component means
private  double muunrep
          mean of unrepresented components
private  double muy
          mean of the data
private  int n
          data size
private  double[] nn
          occupation numbers for each component (stored as double so that double-valued methods can be used directly)
private  double r
          precision of mu_j
private  boolean randomScan
          random scan or systematic scan
private  double[] s
          inverse component variances (precisions)
private  double sigmasqy
          variance of the data
private  double sunrep
          precision of unrepresented components
private  int thinInterval
          sampling lag
private  double w
          precision of s_j
private  double[] ysum
          component data sum
private  double[] yy
          vector of univariate data points
 
Constructor Summary
IgmmGibbsSampler(double[] data)
          Initialise the Gibbs sampler with data.
 
Method Summary
private  void addComponent()
          handle size of componentwise structures.
 void configure(int iterations, int burnIn, int thinInterval)
          set sampling conditions
private  void debug(int level, java.lang.String string)
          print debug information
(package private)  double[] getMean()
          get the mean of the components
(package private)  double[] getStdDev(double[] mean)
          get the standard deviation of the components
(package private)  double[] getWeights()
          get the mixture weights of the components
private  void gibbs()
          Main method: select an initial state, then repeat a large number of times: (1) select an element, (2) update it conditional on the other elements.
private static double[] increaseSize(double[] array, int step)
          Increase size of array.
(package private)  void initialState()
          Initialisation: starts with one class and assigns data-dependent priors (which Rasmussen justifies in his paper).
static void main(java.lang.String[] args)
          Driver with example data.
private  void removeComponent(int j)
          removes one component from the model
private static double[] removeElement(double[] array, int element)
          removes one element from the array
(package private)  double sampleAlpha()
          sample alpha using ARS.
(package private)  double sampleBeta()
          sample beta using ARS.
(package private)  int sampleC(int i)
          sample component association to data point i with likelihood.
(package private)  int sampleCrpC(int i)
          sample component association to data point i using Chinese restaurant process including likelihood term.
(package private)  int sampleCrpPriorC(int i)
          sample component association to data point i using Chinese restaurant process.
 double sampleGammaDist(double a, double b)
          Gamma distribution with the mean given by parameter b (normally the mean would be a*b)
(package private)  double sampleLambda()
          sample the component means' mean.
(package private)  double sampleMu(int j)
          sample the means.
(package private)  double sampleMuUnrep()
          sample from the prior on mu for unrepresented classes
(package private)  double sampleNormalDist(double mu, double sigmaSquared)
          Normal distribution with variance as parameter instead of standard deviation.
(package private)  double[] samplePi()
          sample the component weights ~ Dirichlet.
(package private)  int samplePriorC(int i)
          sample component association to data point i using Dirichlet distribution.
(package private)  double sampleR()
          sample means' precision.
(package private)  double sampleS(int j)
          sample precision for component j.
(package private)  double sampleSUnrep()
          sample from the prior on s for unrepresented classes
(package private)  double sampleW()
          sample precisions' precision.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

yy

private double[] yy
vector of univariate data points


thinInterval

private int thinInterval
sampling lag


burnIn

private int burnIn
burn-in period


iterations

private int iterations
max iterations


growstep

private int growstep
array grow step


muy

private double muy
mean of the data


sigmasqy

private double sigmasqy
variance of the data


alpha

private double alpha
CRP concentration parameter


alphaSampler

private IgmmGibbsSampler.AlphaArms alphaSampler
sampler for alpha


lambda

private double lambda
mean of mu_j


beta

private double beta
mean of s_j


betaSampler

private IgmmGibbsSampler.BetaArms betaSampler
sampler for beta


r

private double r
precision of mu_j


w

private double w
precision of s_j


mu

private double[] mu
component means


muunrep

private double muunrep
mean of unrepresented components


s

private double[] s
inverse component variances (precisions)


sunrep

private double sunrep
precision of unrepresented components


k

private int k
number of components


nn

private double[] nn
occupation numbers for each component (stored as double so that double-valued methods can be used directly)


cc

private int[] cc
state (component) for each data point


n

private int n
data size


ysum

private double[] ysum
component data sum


randomScan

private boolean randomScan
random scan or systematic scan


debugLevel

public int debugLevel
debug level for output (5=info, 1=error)

Constructor Detail

IgmmGibbsSampler

public IgmmGibbsSampler(double[] data)
Initialise the Gibbs sampler with data.

Parameters:
data -
Method Detail

configure

public void configure(int iterations,
                      int burnIn,
                      int thinInterval)
set sampling conditions

Parameters:
iterations -
burnIn -
thinInterval -

initialState

void initialState()
Initialisation: starts with one class and assigns data-dependent priors (which Rasmussen justifies in his paper).


addComponent

private void addComponent()
handle size of componentwise structures.

Note: We use arrays for the components to get more readable syntax and to avoid the possible speed loss of cast operations when accessing a Vector. Therefore all loops over components should explicitly use k, not, e.g., mu.length. The drawback of this approach is that it is harder to remove unoccupied classes.


increaseSize

private static double[] increaseSize(double[] array,
                                     int step)
Increase size of array.

Parameters:
array -
step -
Returns:
longer array
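
A sketch of one plausible implementation of such a grow-by-step helper (an assumption, not the actual code):

 private static double[] increaseSize(double[] array, int step) {
     // allocate a longer array and copy the old contents; the extra
     // "step" slots remain zero until new components fill them
     double[] larger = new double[array.length + step];
     System.arraycopy(array, 0, larger, 0, array.length);
     return larger;
 }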

removeComponent

private void removeComponent(int j)
removes one component from the model

Parameters:
j -

removeElement

private static double[] removeElement(double[] array,
                                      int element)
removes one element from the array

Parameters:
array -
element -
Returns:
shorter array
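
Again a sketch of one plausible implementation, not the actual code:

 private static double[] removeElement(double[] array, int element) {
     // copy everything except index "element" into a shorter array
     double[] shorter = new double[array.length - 1];
     System.arraycopy(array, 0, shorter, 0, element);
     System.arraycopy(array, element + 1, shorter, element,
             array.length - element - 1);
     return shorter;
 }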

getMean

double[] getMean()
get the mean of the components

Returns:

getStdDev

double[] getStdDev(double[] mean)
get the standard deviation of the components

Parameters:
mean -
Returns:

getWeights

double[] getWeights()
get the mixture weights of the components

Returns:

sampleAlpha

double sampleAlpha()
sample alpha using ARS. Eq. (15)

Returns:

sampleBeta

double sampleBeta()
sample beta using ARS. Eq. (9b)

Returns:

samplePriorC

int samplePriorC(int i)
sample component association to data point i using Dirichlet distribution. Eq. (13) (Not necessary; can be sampled via sampleC())

Parameters:
i -
Returns:
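
Eq. (13) is the conditional prior obtained by integrating the mixing proportions out of the symmetric Dirichlet(alpha/k) prior. A sketch of the resulting weights; the helper and its parameter names are illustrative, not the internal fields:

 // p(c_i = j | c_-i, alpha) = (n_-i,j + alpha/k) / (n - 1 + alpha)
 static double[] priorWeights(double[] countsWithoutI, int n, double alpha) {
     int k = countsWithoutI.length;
     double[] p = new double[k];
     for (int j = 0; j < k; j++) {
         p[j] = (countsWithoutI[j] + alpha / k) / (n - 1 + alpha);
     }
     return p;  // draw the component index from this discrete distribution
 }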

sampleC

int sampleC(int i)
sample component association to data point i with likelihood. Eq. (13) and (1)

Parameters:
i -
Returns:

sampleCrpPriorC

int sampleCrpPriorC(int i)
sample component association to data point i using Chinese restaurant process. Eq. (16). If a new component is sampled, the value of k is returned and a new component can be added by the caller.

Parameters:
i -
Returns:
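
Eq. (16) is the familiar Chinese restaurant process conditional: an existing component j is chosen with weight proportional to its occupation count, a new one with weight proportional to alpha. A sketch with illustrative names; the convention of returning index k for a new component follows the description above:

 static double[] crpPriorWeights(double[] countsWithoutI, int n, double alpha) {
     int k = countsWithoutI.length;
     double[] p = new double[k + 1];
     for (int j = 0; j < k; j++) {
         p[j] = countsWithoutI[j] / (n - 1 + alpha);   // existing component j
     }
     p[k] = alpha / (n - 1 + alpha);                   // unrepresented component
     return p;  // draw the component index from this discrete distribution
 }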

sampleCrpC

int sampleCrpC(int i)
sample component association to data point i using Chinese restaurant process including likelihood term. Eq. (17)

Parameters:
i -
Returns:
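
Eq. (17) weights the CRP conditional by the Gaussian likelihood of data point i: represented components use (mu_j, s_j), while the unrepresented component presumably uses the parameters drawn from the priors (cf. muunrep and sunrep). A sketch reusing crpPriorWeights() from the previous entry; all names are illustrative:

 static double[] crpPosteriorWeights(double y, double[] countsWithoutI, int n,
         double alpha, double[] mu, double[] s, double muUnrep, double sUnrep) {
     double[] p = crpPriorWeights(countsWithoutI, n, alpha);
     int k = countsWithoutI.length;
     for (int j = 0; j < k; j++) {
         // N(y | mu_j, 1/s_j) up to the common 1/sqrt(2 pi) factor
         p[j] *= Math.sqrt(s[j]) * Math.exp(-0.5 * s[j] * (y - mu[j]) * (y - mu[j]));
     }
     p[k] *= Math.sqrt(sUnrep) * Math.exp(-0.5 * sUnrep * (y - muUnrep) * (y - muUnrep));
     return p;  // normalise and draw the component index from this distribution
 }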

sampleGammaDist

public double sampleGammaDist(double a,
                              double b)
Gamma distribution with the mean given by parameter b (normally the mean would be a*b)

Parameters:
a -
b -
Returns:
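
Under Rasmussen's convention, Gamma(a, b) has shape a/2 and mean b, which corresponds to scale 2b/a in the usual shape-scale parametrization. A sketch of that mapping; sampleStandardGamma() is a hypothetical placeholder for whatever shape-scale Gamma sampler is actually used:

 public double sampleGammaDist(double a, double b) {
     // Rasmussen's Gamma(a, b): shape a/2, mean b  =>  scale = b / (a/2) = 2b/a
     double shape = a / 2.0;
     double scale = 2.0 * b / a;
     return sampleStandardGamma(shape, scale);  // hypothetical shape-scale sampler
 }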

sampleNormalDist

double sampleNormalDist(double mu,
                        double sigmaSquared)
Normal distribution with variance as parameter instead of standard deviation.

Parameters:
mu -
sigmaSquared -
Returns:
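
A sketch of the variance parametrization; the Random instance rand is assumed:

 double sampleNormalDist(double mu, double sigmaSquared) {
     // scale a standard normal draw by the standard deviation sqrt(sigmaSquared)
     return mu + Math.sqrt(sigmaSquared) * rand.nextGaussian();
 }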

sampleLambda

double sampleLambda()
sample the component means' mean. Eq. (5)

Returns:

sampleMu

double sampleMu(int j)
sample the means.

TODO: the means could be updated whenever component associations change, so mu[j] would not need to be recalculated here.

Parameters:
j -
Returns:
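
The conditional for mu_j follows from standard Gaussian conjugacy between the N(lambda, 1/r) prior and the n_j observations of precision s_j assigned to component j. A sketch of that update, presumably computed from nn[j] and ysum[j]:

 // posterior precision:  prec = r + n_j * s_j
 // posterior mean:       (lambda * r + s_j * sum_i y_i) / prec
 double prec = r + nn[j] * s[j];
 double mean = (lambda * r + s[j] * ysum[j]) / prec;
 double muJ  = sampleNormalDist(mean, 1.0 / prec);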

sampleMuUnrep

double sampleMuUnrep()
sample from the prior on mu for unrepresented classes

Returns:
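
For an unrepresented class the mean is simply drawn from its prior in the model above, mu_j ~ N(lambda, 1/r); a sketch:

 double muUnrep = sampleNormalDist(lambda, 1.0 / r);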

samplePi

double[] samplePi()
sample the component weights ~ Dirichlet. Eq. (10). (Not necessary since posterior for occupation counts is used)

Returns:
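
Under the symmetric Dirichlet(alpha/k) prior, the posterior over the weights given the occupation numbers is Dirichlet(alpha/k + n_1, ..., alpha/k + n_k). A standard way to draw from a Dirichlet is to normalise independent Gamma variates; in this sketch gammaShapeScale() is a hypothetical shape-scale Gamma sampler:

 double[] pi = new double[k];
 double sum = 0;
 for (int j = 0; j < k; j++) {
     pi[j] = gammaShapeScale(alpha / k + nn[j], 1.0);  // shape alpha/k + n_j, scale 1
     sum += pi[j];
 }
 for (int j = 0; j < k; j++) {
     pi[j] /= sum;  // project onto the simplex
 }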

sampleR

double sampleR()
sample means' precision. Eq. (5)

Returns:

sampleS

double sampleS(int j)
sample precision for component j. Eq. (8)

Parameters:
j -
Returns:

sampleSUnrep

double sampleSUnrep()
sample from the prior on s for unrepresented classes

Returns:

sampleW

double sampleW()
sample precisions' precision. Eq. (9)

Returns:

gibbs

private void gibbs()
Main method: select an initial state, then repeat a large number of times: (1) select an element, (2) update it conditional on the other elements. If appropriate, output a summary for each run.
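
One plausible shape of such a sweep, pieced together from the sampling methods documented above; the exact update order and the bookkeeping for new and emptied components are assumptions:

 for (int iter = 0; iter < iterations; iter++) {
     for (int i = 0; i < n; i++) {                  // or a random scan
         int j = sampleCrpC(i);                     // reassign data point i
         if (j == k) addComponent();                // a new component was opened
         // ... update cc[i], nn and ysum, remove emptied components ...
     }
     for (int j = 0; j < k; j++) {
         mu[j] = sampleMu(j);
         s[j] = sampleS(j);
     }
     lambda = sampleLambda();  r = sampleR();       // mean hyperparameters
     w = sampleW();  beta = sampleBeta();           // precision hyperparameters
     alpha = sampleAlpha();                         // CRP concentration
     muunrep = sampleMuUnrep();  sunrep = sampleSUnrep();
     // after burnIn, record estimates every thinInterval iterations
 }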

debug

private void debug(int level,
                   java.lang.String string)
print debug information

Parameters:
level - debug level (5=info to 1=error)
string -

main

public static void main(java.lang.String[] args)
Driver with example data.

Parameters:
args -