org.knowceans.dirichlet.lda
Class LdaGibbsSampler

java.lang.Object
  extended by org.knowceans.dirichlet.lda.LdaGibbsSampler
All Implemented Interfaces:
java.io.Serializable
Direct Known Subclasses:
AtmGibbsSampler, HldaGibbsSampler, IldaGibbsSampler, LdaGibbsQuerySampler, LdaGibbsSamplerHyper

public class LdaGibbsSampler
extends java.lang.Object
implements java.io.Serializable

Gibbs sampler for estimating the topic assignments of the words in a corpus and, from them, the topic distributions of documents and the word distributions of topics. The algorithm is introduced in Tom Griffiths' paper "Gibbs sampling in the generative model of Latent Dirichlet Allocation" (2002).

TODO: clean up the constructors so that invalid initialisations become harder.

Author:
heinrich
See Also:
Serialized Form

Field Summary
 int backupIteration
          iteration in the last backup
protected  ExtLdaConfiguration conf
          Configuration object with the current parameters.
 int dispcol
           
protected  int numstats
          size of statistics
protected  double[][] phisum
          cumulative statistics of phi
protected  java.util.Random rand
          Random generator
private static long serialVersionUID
           
protected  LdaMarkovState state
          State variables of the LDA Gibbs sampler.
protected  double[][] thetasum
          cumulative statistics of theta
 
Constructor Summary
protected LdaGibbsSampler()
          For subclasses that know what they are doing.
  LdaGibbsSampler(int[][] documents, int V, double alpha, double beta, int K, int iterations)
          Initialise the Gibbs sampler with data and standard values.
  LdaGibbsSampler(int[][] documents, int V, ExtLdaConfiguration conf)
          Initialise the sampler with the documents and the configuration.
  LdaGibbsSampler(int[][] documents, int V, ExtLdaConfiguration conf, java.util.Random rand)
          Initialise the sampler with the documents and the configuration.
  LdaGibbsSampler(ITermCorpus corpus, ExtLdaConfiguration conf)
          Initialise the sampler with the corpus and the configuration.
  LdaGibbsSampler(ITermCorpus corpus, ExtLdaConfiguration conf, java.util.Random rand)
          Initialise the sampler with the corpus and the configuration.
  LdaGibbsSampler(ITermCorpus corpus, LdaMarkovState state, ExtLdaConfiguration conf, java.util.Random rand)
          Initialise the sampler with an existing state.
protected LdaGibbsSampler(LdaMarkovState state, ExtLdaConfiguration conf, java.util.Random rand)
          Initialise the sampler with an existing state.
 
Method Summary
 double[][] getPhi()
          Retrieve estimated topic--word associations.
 LdaMarkovState getState()
          Get the current state of the Markov chain.
 double[][] getTheta()
          Retrieve estimated document--topic associations.
protected  void gibbs()
          Main method: select an initial state, then repeatedly update each element conditional on the others.
 int[][] gibbs(int[][] w, int V, int[][] z, int K, double alpha, double beta, int iter)
          Native implementation of the Gibbs sampling procedure.
 int[][] gibbsHeap(int[][] w, int[][] z, int[][] nw, int[] nwsum, int[][] nd, int[] ndsum, double alpha, double beta, int iter)
          Native Gibbs sampling on the JVM heap.
 void gibbsHeap(LdaMarkovState s, ExtLdaConfiguration c)
          Native Gibbs sampling on the JVM heap.
protected  void initialState()
          Initialisation: random topic assignments with equal probabilities.
static java.lang.Object load(java.lang.String filename)
          Read an object from the object stream in the given file.
static void main(java.lang.String[] args)
           
 void output(int i)
          Handle output during sampling
 void run()
          Run the sampler after initialisation.
protected  void sampleCorpus(LdaMarkovState s)
          Sample once through the corpus and update the corresponding state.
protected  int sampleLdaFullConditional(LdaMarkovState s, int m, int n)
          Sample a topic z_i from the full conditional distribution p(z_i = j | z_{-i}, w).
 void save(java.lang.String filename)
          Write the sampler to an object stream (for testing only).
 void saveState(java.lang.String file)
          Saves the current state of the Markov chain and the parameters to a file.
protected  void updateParams()
          Add the values of theta and phi for the current state to the cumulative statistics.
protected  void updatePhi()
          Update the topic--term associations.
protected  void updateTheta()
          Update the document--topic associations.
protected  void writeParameters(java.lang.String file, org.knowceans.util.Arguments a, ITermCorpus corpus)
          Write statistics of the current run to a text file for later review.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

serialVersionUID

private static final long serialVersionUID
See Also:
Constant Field Values

conf

protected ExtLdaConfiguration conf
Configuration object with the current parameters.


state

protected LdaMarkovState state
State variables of the LDA Gibbs sampler.


thetasum

protected double[][] thetasum
cumulative statistics of theta


phisum

protected double[][] phisum
cumulative statistics of phi


backupIteration

public int backupIteration
iteration in the last backup


numstats

protected int numstats
size of statistics


dispcol

public int dispcol

rand

protected java.util.Random rand
Random generator

Constructor Detail

LdaGibbsSampler

protected LdaGibbsSampler()
For subclasses that know what they are doing.


LdaGibbsSampler

public LdaGibbsSampler(int[][] documents,
                       int V,
                       double alpha,
                       double beta,
                       int K,
                       int iterations)
Initialise the Gibbs sampler with data and standard values. (For backwards compatibility).

Parameters:
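documents - word indices of each document (documents[m][n] is the term index of word n in document m)
V - vocabulary size
alpha - Dirichlet hyperparameter of the document--topic distributions
beta - Dirichlet hyperparameter of the topic--word distributions
K - number of topics
iterations - number of Gibbs iterations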

LdaGibbsSampler

public LdaGibbsSampler(int[][] documents,
                       int V,
                       ExtLdaConfiguration conf)
Initialise the sampler with the documents and the configuration.

Parameters:
documents - word indices of each document
V - vocabulary size
conf - configuration with the sampling parameters

LdaGibbsSampler

public LdaGibbsSampler(int[][] documents,
                       int V,
                       ExtLdaConfiguration conf,
                       java.util.Random rand)
Initialise the sampler with the documents and the configuration.

Parameters:
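documents - word indices of each document
V - vocabulary size
conf - configuration with the sampling parameters
rand - random number generator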

LdaGibbsSampler

public LdaGibbsSampler(ITermCorpus corpus,
                       ExtLdaConfiguration conf)
Initialise the sampler with the corpus and the configuration.

Parameters:
corpus - term corpus
conf - configuration with the sampling parameters

LdaGibbsSampler

public LdaGibbsSampler(ITermCorpus corpus,
                       ExtLdaConfiguration conf,
                       java.util.Random rand)
Initialise the sampler with the corpus and the configuration.

Parameters:
corpus - term corpus
conf - configuration with the sampling parameters
rand - random number generator

LdaGibbsSampler

protected LdaGibbsSampler(LdaMarkovState state,
                          ExtLdaConfiguration conf,
                          java.util.Random rand)
Initialise the sampler with an existing state.

Parameters:
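state - existing state of the Markov chain
conf - configuration with the sampling parameters
rand - random number generator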

LdaGibbsSampler

public LdaGibbsSampler(ITermCorpus corpus,
                       LdaMarkovState state,
                       ExtLdaConfiguration conf,
                       java.util.Random rand)
Initialise the sampler with an existing state.

Parameters:
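corpus - term corpus
state - existing state of the Markov chain
conf - configuration with the sampling parameters
rand - random number generator

Method Detail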

initialState

protected void initialState()
Initialisation: random topic assignments with equal probabilities.
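A minimal sketch of such an initialisation, assuming count arrays laid out as nw[term][topic], nd[document][topic], nwsum[topic] and ndsum[document]. The class LdaCounts and its method randomInit are hypothetical stand-ins, not LdaMarkovState or this class's API:

```java
import java.util.Random;

/** Hypothetical container for the sampler's counts; not the library's LdaMarkovState. */
class LdaCounts {
    int[][] z;      // z[m][n]: topic assigned to word n of document m
    int[][] nw;     // nw[v][k]: occurrences of term v assigned to topic k (assumed layout)
    int[] nwsum;    // nwsum[k]: total words assigned to topic k
    int[][] nd;     // nd[m][k]: words of document m assigned to topic k
    int[] ndsum;    // ndsum[m]: length of document m

    /** Assign every word a topic drawn uniformly from 0..K-1 and build the counts. */
    static LdaCounts randomInit(int[][] w, int V, int K, Random rand) {
        LdaCounts s = new LdaCounts();
        int M = w.length;
        s.z = new int[M][];
        s.nw = new int[V][K];
        s.nwsum = new int[K];
        s.nd = new int[M][K];
        s.ndsum = new int[M];
        for (int m = 0; m < M; m++) {
            int N = w[m].length;
            s.z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                int k = rand.nextInt(K);   // every topic equally likely
                s.z[m][n] = k;
                s.nw[w[m][n]][k]++;
                s.nd[m][k]++;
                s.nwsum[k]++;
            }
            s.ndsum[m] = N;
        }
        return s;
    }
}
```

Whatever the random draws, the counts stay consistent with z: each row of nd sums to the document length, and nwsum sums to the corpus size.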


run

public void run()
Run the sampler after initialisation.


gibbs

protected void gibbs()
Main method: select an initial state, then repeat a large number of times:
1. select an element;
2. update it conditional on the other elements.
If appropriate, output a summary for each run.
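The loop described above can be sketched as follows. The names burnIn and sampleLag and the Sweep callback are hypothetical; in this class such values presumably come from ExtLdaConfiguration, whose fields are not documented here. The rule for recording a sample after burn-in every sampleLag iterations is an assumption, modelled on the "sample lag" mentioned under getTheta:

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical skeleton of a Gibbs sampling loop with burn-in and sample lag. */
class GibbsLoopSketch {
    /** One full pass over all words (the role sampleCorpus plays in this class). */
    interface Sweep { void run(int iteration); }

    /**
     * Runs `iterations` sweeps. After `burnIn` iterations, every iteration divisible
     * by `sampleLag` is recorded as a sample; returns the indices of those iterations.
     */
    static List<Integer> run(int iterations, int burnIn, int sampleLag, Sweep sweep) {
        List<Integer> sampled = new ArrayList<>();
        for (int i = 0; i < iterations; i++) {
            sweep.run(i);
            if (i > burnIn && sampleLag > 0 && i % sampleLag == 0) {
                sampled.add(i);   // here a sampler would accumulate theta/phi statistics
            }
        }
        return sampled;
    }
}
```

For example, run(10, 4, 2, i -> {}) performs ten sweeps and records iterations 6 and 8.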


gibbs

public int[][] gibbs(int[][] w,
                     int V,
                     int[][] z,
                     int K,
                     double alpha,
                     double beta,
                     int iter)
Native implementation of the Gibbs sampling procedure. It is defined in this class so that subclasses can access it.

Parameters:
w - words
V - vocabulary size
z - topic associations
K - topic count
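alpha - Dirichlet hyperparameter of the document--topic distributions
beta - Dirichlet hyperparameter of the topic--word distributions
iter - number of iterations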
Returns:
the new assignments z

gibbsHeap

public int[][] gibbsHeap(int[][] w,
                         int[][] z,
                         int[][] nw,
                         int[] nwsum,
                         int[][] nd,
                         int[] ndsum,
                         double alpha,
                         double beta,
                         int iter)
Native Gibbs sampling on the JVM heap.

Parameters:
w - [in] words
z - [in/out] topic associations
nw - [in/out] topic-word counts
nwsum - [in/out] summed topic-word counts (total words per topic)
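nd - [in/out] document-topic counts (words in each document assigned to each topic)
ndsum - [in] document lengths
alpha - Dirichlet hyperparameter of the document--topic distributions
beta - Dirichlet hyperparameter of the topic--word distributions
iter - number of iterations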
Returns:

gibbsHeap

public void gibbsHeap(LdaMarkovState s,
                      ExtLdaConfiguration c)
Native Gibbs sampling on the JVM heap.

Parameters:
s - [in/out] state
c - [in] configuration

output

public void output(int i)
Handle output during sampling

Parameters:
i -

saveState

public void saveState(java.lang.String file)
Saves the current state of the Markov chain and the parameters to a file.

Parameters:
file -

sampleCorpus

protected void sampleCorpus(LdaMarkovState s)
Sample once through the corpus and update the corresponding state. The parameter is used to choose the state to be sampled from: query vs. corpus; choice of chain for multichain sampling.

Parameters:
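s - the state to sample from (for example the query or corpus state, or one chain in multichain sampling)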

sampleLdaFullConditional

protected int sampleLdaFullConditional(LdaMarkovState s,
                                       int m,
                                       int n)
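Sample a topic z_i from the full conditional distribution

$$p(z_i = j \mid \mathbf{z}_{-i}, \mathbf{w}) \propto \frac{n^{(w_i)}_{-i,j} + \beta}{n^{(\cdot)}_{-i,j} + V\beta} \cdot \frac{n^{(d_i)}_{-i,j} + \alpha}{n^{(d_i)}_{-i,\cdot} + K\alpha}$$

where the counts n_{-i} exclude the current word i, V is the vocabulary size and K the number of topics.

Parameters:
s - the state to sample from
m - document
n - position of the word within document m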
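As a worked check of the formula, the hypothetical helper below evaluates it for every topic and normalises the result. It assumes the caller passes counts that already exclude word i; FullConditionalSketch and its parameter names are illustrative, not part of this class:

```java
/** Hypothetical helper: normalised full conditional p(z_i = j | z_{-i}, w). */
class FullConditionalSketch {
    /**
     * nwV[j]         occurrences of term w_i assigned to topic j (excluding word i)
     * nwsum[j]       words assigned to topic j (excluding word i)
     * ndM[j]         words of document d_i assigned to topic j (excluding word i)
     * ndsumMinusI    length of document d_i minus one
     */
    static double[] fullConditional(int[] nwV, int[] nwsum, int[] ndM, int ndsumMinusI,
                                    int V, double alpha, double beta) {
        int K = nwsum.length;
        double[] p = new double[K];
        double norm = 0;
        for (int j = 0; j < K; j++) {
            p[j] = (nwV[j] + beta) / (nwsum[j] + V * beta)
                 * (ndM[j] + alpha) / (ndsumMinusI + K * alpha);
            norm += p[j];
        }
        for (int j = 0; j < K; j++) p[j] /= norm;   // turn "proportional to" into probabilities
        return p;
    }
}
```

For example, with V = 2, alpha = beta = 1, nwV = {1, 0}, nwsum = {1, 1}, ndM = {1, 0} and ndsumMinusI = 1, the unnormalised values are 4/9 and 1/9, so the call returns {0.8, 0.2}.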

updateParams

protected void updateParams()
Add the values of theta and phi for the current state to the cumulative statistics.
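A minimal sketch of this bookkeeping for theta, together with the averaging that getTheta describes when the sample lag is > 0. It assumes numstats counts how many samples have been added and that the returned estimate is thetasum / numstats; the class ThetaAverageSketch and its methods are hypothetical (phi would be handled the same way with phisum):

```java
/** Hypothetical sketch of accumulating sampled theta matrices and averaging them. */
class ThetaAverageSketch {
    double[][] thetasum;   // running sum of sampled theta matrices (M x K)
    int numstats;          // number of samples added so far (assumed meaning)

    ThetaAverageSketch(int M, int K) {
        thetasum = new double[M][K];
    }

    /** Add the theta of the current state to the running sum. */
    void add(double[][] theta) {
        for (int m = 0; m < theta.length; m++)
            for (int k = 0; k < theta[m].length; k++)
                thetasum[m][k] += theta[m][k];
        numstats++;
    }

    /** Mean of all added samples. */
    double[][] mean() {
        double[][] out = new double[thetasum.length][];
        for (int m = 0; m < thetasum.length; m++) {
            out[m] = new double[thetasum[m].length];
            for (int k = 0; k < out[m].length; k++)
                out[m][k] = thetasum[m][k] / numstats;
        }
        return out;
    }
}
```

Averaging several lagged samples smooths out the noise of any single state of the chain.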


updateTheta

protected void updateTheta()
Update the document--topic associations.


updatePhi

protected void updatePhi()
Update the topic--term associations.
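The updateTheta and updatePhi methods above presumably compute the usual LDA point estimates from the current counts: theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha) and phi[k][v] = (nw[v][k] + beta) / (nwsum[k] + V * beta). The sketch below shows those estimates, assuming the count layouts nw[term][topic] and nd[document][topic]; ParamEstimateSketch is a hypothetical name:

```java
/** Hypothetical sketch of the standard LDA estimates of theta and phi from counts. */
class ParamEstimateSketch {
    /** Document--topic estimates, an M x K matrix; each row sums to 1. */
    static double[][] theta(int[][] nd, int[] ndsum, double alpha) {
        int M = nd.length, K = nd[0].length;
        double[][] theta = new double[M][K];
        for (int m = 0; m < M; m++)
            for (int k = 0; k < K; k++)
                theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
        return theta;
    }

    /** Topic--word estimates, a K x V matrix; assumes nw is indexed nw[term][topic]. */
    static double[][] phi(int[][] nw, int[] nwsum, double beta) {
        int V = nw.length, K = nwsum.length;
        double[][] phi = new double[K][V];
        for (int k = 0; k < K; k++)
            for (int v = 0; v < V; v++)
                phi[k][v] = (nw[v][k] + beta) / (nwsum[k] + V * beta);
        return phi;
    }
}
```

The matrix shapes match those stated under getTheta (M x K) and getPhi (K x V).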


getTheta

public double[][] getTheta()
Retrieve estimated document--topic associations. If sample lag > 0 then the mean value of all sampled statistics for theta[][] is taken.

Returns:
theta multinomial mixture of document topics (M x K)

getPhi

public double[][] getPhi()
Retrieve estimated topic--word associations. If sample lag > 0 then the mean value of all sampled statistics for phi[][] is taken.

Returns:
phi multinomial mixture of topic words (K x V)

main

public static void main(java.lang.String[] args)

save

public void save(java.lang.String filename)
Write the sampler to an object stream (for testing only).
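Since LdaGibbsSampler implements java.io.Serializable, object-stream persistence of this kind can be sketched with the standard java.io classes. ObjectStreamSketch is hypothetical, and unlike the instance method save(String) it takes the object to write as a parameter:

```java
import java.io.*;

/** Hypothetical sketch of saving and loading a Serializable object via object streams. */
class ObjectStreamSketch {
    /** Serialise obj to the given file. */
    static void save(Serializable obj, String filename) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(filename))) {
            out.writeObject(obj);
        }
    }

    /** Read back one object written by save. */
    static Object load(String filename) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(filename))) {
            return in.readObject();
        }
    }
}
```

The caller casts the result of load to the expected type; Java serialization is tied to class versions (serialVersionUID), which fits the "for testing only" note rather than long-term storage.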


load

public static java.lang.Object load(java.lang.String filename)
Read an object from the object stream in the given file.

Parameters:
filename -
Returns:

writeParameters

protected void writeParameters(java.lang.String file,
                               org.knowceans.util.Arguments a,
                               ITermCorpus corpus)
Write statistics of the current run to a text file for later review.

Parameters:
file -
a - Arguments object
corpus - term corpus of the current run

getState

public LdaMarkovState getState()
Get the current state of the Markov chain.

Returns: