/*
* Hpam2p.java -- HPAM2 Gibbs sampling implementation
* Copyright (C) 2009-2011 Gregor Heinrich, gregor :: arbylon : net
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Implementation using parallelised samplers.
*
* Mixture network specification:
*
* m >> (theta[m] | alpha) >> x[m][n]
* m, x[m][n] >> (thetax[m,x] | alphax[x]) >> y[m][n]
* x[m][n], y[m][n] >> (phi[k] | beta) >> w[m][n]
*
* k: {
* if (x==0) k = 0;
* else if (y==0) k = 1 + x;
* else k = 1 + X + y;
* }.
*
* elements:
*
* document m
* type: {ROOT|E3COUPLED}
* parents: (root)
* children: theta thetax
* range: M
* ----
* doc-suptop (theta[m] | alpha)
* type: {SEQUENCE|C1ROOT}
* parents: m
* children: x
* components: theta, domain: M, range: X
* counts: nmx, sum: null
* index: m, selector: null
* hyperparams: alpha, dimension X, fixed: false
* ----
* top-subtop (thetax[m,x] | alphax[x])
* type: {SEQUENCE|C1BSEQADD}
* parents: m x
* children: y
* components: thetax, domain: M * X, range: Y
* counts: nmxy, sum: nmxysum
* index: mx, selector: m,x
* hyperparams: alphax, dimension: X * Y, selector: x, fixed: false
* ----
* suptopic x[m][n]
* type: {HIDDEN|E3COUPLED}
* parents: theta
* children: thetax phi
* range: X
* ----
* subtopic y[m][n]
* type: {HIDDEN|E1SINGLE}
* parents: thetax
* children: phi
* range: Y
* ----
* hiertop-word (phi[k] | beta)
* type: {TOPIC|C2MULTI}
* parents: x y
* children: w
* components: phi, domain: 1 + X + Y, range: V
* counts: nkw, sum: nkwsum
* index: k, selector:
* if (x==0) k = 0;
* else if (y==0) k = 1 + x;
* else k = 1 + X + y;
* hyperparams: beta, dimension 1, fixed: false
* ----
* word w[m][n]
* type: {VISIBLE|E1SINGLE}
* parents: phi
* children: (leaf)
* range: V
* ----
* sequences:
*
* words [m][n]
* parent: (root), children: []
* edges: m x y w
*
*
* @author Gregor Heinrich, gregor :: arbylon : net (via MixNetKernelGenerator)
* @date generated on 11 Mar 2011
*/
// HPAM2 Gibbs sampler: fields, constructor, training/query kernels, perplexity
/**
 * Parallelised collapsed Gibbs sampler for the HPAM2 model described in the
 * file header. Documents are distributed over {@code P} worker threads. The
 * shared topic-word counts {@code nkw}/{@code nkwsum} are written under a
 * per-row lock ({@code nkw[k]}), while reads of them during the weight
 * computation are not synchronised.
 */
public class Hpam2pGibbsSampler {
    // //////////////// fields //////////////////
    // one random generator per worker thread (index = thread id)
    Random[] rand;
    // index of the current training iteration; read by estAlpha() for burn-in
    int iter;
    // number of iterations requested by the last run() call
    int niter;
    // root edge: document (training and query document counts)
    int M;
    int Mq;
    // sequence node: doc-suptop, counts nmx[m][x]
    int[][] nmx;
    int[][] nmxq;
    double[] alpha;
    // must equal the sum of alpha[]; kept in sync by estAlpha()
    double alphasum;
    // sequence node: top-subtop, counts nmxy[X * m + x][y], row sums nmxysum
    int[][] nmxy;
    int[][] nmxyq;
    int[] nmxysum;
    int[] nmxysumq;
    double[][] alphax;
    // alphaxsum[x] must equal the sum of alphax[x][]; kept in sync by estAlpha()
    double[] alphaxsum;
    // hidden edge: suptopic
    int[][] x;
    int[][] xq;
    int X;
    // hidden edge: subtopic
    int[][] y;
    int[][] yq;
    int Y;
    // topic node: hiertop-word, rows indexed by topicIndex(x, y)
    int[][] nkw;
    int[] nkwsum;
    double beta;
    // must equal V * beta; kept in sync by estAlpha()
    double betasum;
    double[][] phi;
    // visible edge: word
    int[][] w;
    int[][] wq;
    int V;
    // sequence: words (total token counts of training / query corpus)
    int W;
    int Wq;
    // sampling weights: one X-by-Y table per worker thread
    double[][][] pp;
    int P;

    // //////////////// main //////////////////

    /**
     * Demo driver: loads the "nips/nips" corpus, splits it, and prints the query
     * perplexity before and after training.
     */
    public static void main(String[] args) {
        String filebase = "nips/nips";
        Random rand = new CokusRandom();
        // set up corpus
        LabelNumCorpus corpus = new LabelNumCorpus(filebase);
        corpus.split(10, 1, rand);
        LabelNumCorpus train = (LabelNumCorpus) corpus.getTrainCorpus();
        LabelNumCorpus test = (LabelNumCorpus) corpus.getTestCorpus();
        // TODO: adjust data source
        int[][] w = train.getDocWords(rand);
        int[][] wq = test.getDocWords(rand);
        int V = corpus.getNumTerms();
        // parameters
        int X = 20;
        int Y = 20;
        double alpha = 0.1;
        double alphax = 0.1;
        double beta = 0.1;
        int niter = 100, niterq = 10;
        // create sampler
        Hpam2pGibbsSampler gs = new Hpam2pGibbsSampler(alpha, alphax, X, Y,
                beta, w, wq, V, rand);
        gs.init();
        System.out.println(gs);
        // initial test: query sampling against the randomly initialised model
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        // run sampler (init() discards the previous random state)
        StopWatch.start();
        gs.init();
        gs.run(niter);
        System.out.println(StopWatch.format(StopWatch.stop()));
        // final test
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        System.out.println(gs);
    } // main

    // //////////////// constructor //////////////////

    /**
     * Creates a sampler with symmetric initial hyperparameters.
     *
     * @param alpha  initial value of every alpha[x]
     * @param alphax initial value of every alphax[x][y]
     * @param X      number of super-topic values
     * @param Y      number of sub-topic values
     * @param beta   initial topic-word hyperparameter
     * @param w      training documents as term indices w[m][n] in [0, V)
     * @param wq     query documents as term indices wq[m][n] in [0, V)
     * @param V      vocabulary size
     * @param rand   seed source for the per-thread generators
     */
    public Hpam2pGibbsSampler(double alpha, double alphax, int X, int Y,
            double beta, int[][] w, int[][] wq, int V, Random rand) {
        this.X = X;
        this.Y = Y;
        this.w = w;
        this.wq = wq;
        this.V = V;
        this.alpha = new double[X];
        for (int i = 0; i < X; i++) {
            this.alpha[i] = alpha;
        }
        this.alphasum = X * alpha;
        this.alphax = new double[X][Y];
        this.alphaxsum = new double[X];
        for (int mxj = 0; mxj < X; mxj++) {
            for (int t = 0; t < Y; t++) {
                this.alphax[mxj][t] = alphax;
            }
            this.alphaxsum[mxj] = Y * alphax;
        }
        this.beta = beta;
        this.betasum = V * beta;
        // count tokens
        M = w.length;
        W = 0;
        for (int m = 0; m < M; m++) {
            W += w[m].length;
        }
        Mq = wq.length;
        Wq = 0;
        for (int m = 0; m < Mq; m++) {
            Wq += wq[m].length;
        }
        // one generator and one weight table per available processor
        P = Runtime.getRuntime().availableProcessors();
        this.rand = new CokusRandom[P];
        for (int p = 0; p < P; p++) {
            this.rand[p] = new CokusRandom(rand.nextInt());
        }
        pp = new double[P][X][Y];
    } // c'tor

    // //////////////// helpers //////////////////

    /**
     * Maps a (super-topic, sub-topic) pair to its row in nkw / phi:
     * x == 0 selects row 0, y == 0 selects row 1 + x, otherwise row 1 + X + y.
     */
    private int topicIndex(int hx, int hy) {
        if (hx == 0) {
            return 0;
        }
        if (hy == 0) {
            return 1 + hx;
        }
        return 1 + X + hy;
    }

    /**
     * Inverse-CDF search over an X-by-Y weight table in row-major order.
     *
     * @param weights unnormalised weights, weights[x][y]
     * @param u       threshold in [0, total weight]
     * @return flat index {@code x * Y + y} of the first cell whose cumulative
     *         weight reaches {@code u}
     * @throws IllegalStateException if no cell is reached (e.g. NaN weights)
     */
    private int sampleIndex(double[][] weights, double u) {
        double cumulative = 0;
        for (int hx = 0; hx < X; hx++) {
            for (int hy = 0; hy < Y; hy++) {
                cumulative += weights[hx][hy];
                if (u <= cumulative) {
                    return hx * Y + hy;
                }
            }
        }
        throw new IllegalStateException("sampling weights exhausted: u = " + u
                + ", total = " + cumulative);
    }

    // //////////////// initialisation //////////////////

    /**
     * Allocates the training count arrays and assigns every training token a
     * uniformly random (x, y) pair drawn from rand[0], updating all counts.
     */
    public void init() {
        nmx = new int[M][X];
        nmxy = new int[M * X][Y];
        nmxysum = new int[M * X];
        x = new int[M][];
        y = new int[M][];
        for (int m = 0; m < M; m++) {
            x[m] = new int[w[m].length];
            y[m] = new int[w[m].length];
        }
        nkw = new int[1 + X + Y][V];
        nkwsum = new int[1 + X + Y];
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < w[m].length; n++) {
                int hx = rand[0].nextInt(X);
                int hy = rand[0].nextInt(Y);
                x[m][n] = hx;
                y[m][n] = hy;
                nmx[m][hx]++;
                int mxsel = X * m + hx;
                nmxy[mxsel][hy]++;
                nmxysum[mxsel]++;
                int ksel = topicIndex(hx, hy);
                synchronized (nkw[ksel]) {
                    nkw[ksel][w[m][n]]++;
                    nkwsum[ksel]++;
                }
            }
        }
    } // init

    // //////////////// query initialisation //////////////////

    /**
     * Freezes phi from the current training counts and assigns every query token a
     * uniformly random (x, y) pair drawn from rand[0]. Requires init() to have run.
     */
    public void initq() {
        nmxq = new int[Mq][X];
        nmxyq = new int[Mq * X][Y];
        nmxysumq = new int[Mq * X];
        xq = new int[Mq][];
        yq = new int[Mq][];
        for (int m = 0; m < Mq; m++) {
            xq[m] = new int[wq[m].length];
            yq[m] = new int[wq[m].length];
        }
        // point estimate of every topic-word distribution (one pass per row)
        phi = new double[1 + X + Y][V];
        for (int k = 0; k < phi.length; k++) {
            double norm = nkwsum[k] + betasum;
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkw[k][t] + beta) / norm;
            }
        }
        for (int m = 0; m < Mq; m++) {
            for (int n = 0; n < wq[m].length; n++) {
                int hx = rand[0].nextInt(X);
                int hy = rand[0].nextInt(Y);
                xq[m][n] = hx;
                yq[m][n] = hy;
                nmxq[m][hx]++;
                int mxsel = X * m + hx;
                nmxyq[mxsel][hy]++;
                nmxysumq[mxsel]++;
            }
        }
    } // initq

    // //////////////// main kernel //////////////////

    /**
     * Runs {@code niter} Gibbs sweeps over the training documents (documents in
     * parallel), calling estAlpha() after each sweep.
     *
     * @param niter number of sweeps
     */
    public void run(int niter) {
        this.niter = niter;
        ParallelFor parfor = new ParallelFor(P) {
            @Override
            public void process(int m, int thread) {
                // per-thread weight table; assumes 0 <= thread < P
                double[][] weights = pp[thread];
                for (int n = 0; n < w[m].length; n++) {
                    int term = w[m][n];
                    // remove token (m, n) from all counts
                    int hx = x[m][n];
                    int hy = y[m][n];
                    nmx[m][hx]--;
                    int mxsel = X * m + hx;
                    nmxy[mxsel][hy]--;
                    nmxysum[mxsel]--;
                    int ksel = topicIndex(hx, hy);
                    synchronized (nkw[ksel]) {
                        nkw[ksel][term]--;
                        nkwsum[ksel]--;
                    }
                    /*
                     * &p(x_{m,n} \eq x, y_{m,n} \eq y\; |\;\vec x_{-m,n}, \vec
                     * y_{-m,n}, \vec w, \cdot) \notag\\ &\qquad\propto
                     * (n^{-mn}_{m,x} + \alpha_{x} ) \cdot \frac{n^{-mn}_{mx,y}
                     * + \alpha^\text{x}_{x, y} }{\sum_{y} n^{-mn}_{mx,y} +
                     * \alpha^\text{x}_{x, y}} \cdot \frac{n^{-mn}_{k,w} + \beta
                     * }{\sum_{w} n^{-mn}_{k,w} + \beta}
                     */
                    double psum = 0;
                    for (int cx = 0; cx < X; cx++) {
                        int cmx = X * m + cx;
                        // document-level factor does not depend on y
                        double docWeight = nmx[m][cx] + alpha[cx];
                        for (int cy = 0; cy < Y; cy++) {
                            int k = topicIndex(cx, cy);
                            // nkw/nkwsum are read without the row lock
                            weights[cx][cy] = docWeight
                                    * (nmxy[cmx][cy] + alphax[cx][cy])
                                    / (nmxysum[cmx] + alphaxsum[cx])
                                    * (nkw[k][term] + beta)
                                    / (nkwsum[k] + betasum);
                            psum += weights[cx][cy];
                        }
                    }
                    // draw a new (x, y) pair proportionally to the weights
                    int cell = sampleIndex(weights,
                            rand[thread].nextDouble() * psum);
                    hx = cell / Y;
                    hy = cell % Y;
                    // add token (m, n) back with its new assignment
                    x[m][n] = hx;
                    y[m][n] = hy;
                    nmx[m][hx]++;
                    mxsel = X * m + hx;
                    nmxy[mxsel][hy]++;
                    nmxysum[mxsel]++;
                    ksel = topicIndex(hx, hy);
                    synchronized (nkw[ksel]) {
                        nkw[ksel][term]++;
                        nkwsum[ksel]++;
                    }
                } // for n
            } // process
        }; // foreach m
        try {
            for (int i = 0; i < niter; i++) {
                System.out.println(i);
                // publish the iteration index so estAlpha() can apply its burn-in
                iter = i;
                parfor.loop(M);
                estAlpha();
            }
        } finally {
            // release the worker threads even if a sweep fails
            parfor.shutdown();
        }
    } // run

    // //////////////// query kernel //////////////////

    /**
     * Runs {@code niterq} Gibbs sweeps over the query documents with phi held
     * fixed (as computed by initq()). Does not touch the training counts.
     *
     * @param niterq number of sweeps
     */
    public void runq(int niterq) {
        ParallelFor parfor = new ParallelFor(P) {
            @Override
            public void process(int m, int thread) {
                // per-thread weight table; assumes 0 <= thread < P
                double[][] weights = pp[thread];
                for (int n = 0; n < wq[m].length; n++) {
                    int term = wq[m][n];
                    // remove token (m, n) from the query counts
                    int hx = xq[m][n];
                    int hy = yq[m][n];
                    nmxq[m][hx]--;
                    int mxsel = X * m + hx;
                    nmxyq[mxsel][hy]--;
                    nmxysumq[mxsel]--;
                    /*
                     * &p(x_{m,n} \eq x, y_{m,n} \eq y\; |\;\vec x_{-m,n}, \vec
                     * y_{-m,n}, \vec w, \cdot) \notag\\ &\qquad\propto
                     * (n^{-mn}_{m,x} + \alpha_{x} ) \cdot \frac{n^{-mn}_{mx,y}
                     * + \alpha^\text{x}_{x, y} }{\sum_{y} n^{-mn}_{mx,y} +
                     * \alpha^\text{x}_{x, y}} \cdot \phi_{k,w}
                     */
                    double psum = 0;
                    for (int cx = 0; cx < X; cx++) {
                        int cmx = X * m + cx;
                        double docWeight = nmxq[m][cx] + alpha[cx];
                        for (int cy = 0; cy < Y; cy++) {
                            weights[cx][cy] = docWeight
                                    * (nmxyq[cmx][cy] + alphax[cx][cy])
                                    / (nmxysumq[cmx] + alphaxsum[cx])
                                    * phi[topicIndex(cx, cy)][term];
                            psum += weights[cx][cy];
                        }
                    }
                    // draw a new (x, y) pair proportionally to the weights
                    int cell = sampleIndex(weights,
                            rand[thread].nextDouble() * psum);
                    hx = cell / Y;
                    hy = cell % Y;
                    xq[m][n] = hx;
                    yq[m][n] = hy;
                    nmxq[m][hx]++;
                    mxsel = X * m + hx;
                    nmxyq[mxsel][hy]++;
                    nmxysumq[mxsel]++;
                } // for n
            } // process
        }; // foreach m
        try {
            for (int i = 0; i < niterq; i++) {
                System.out.println(i);
                parfor.loop(Mq);
            }
        } finally {
            parfor.shutdown();
        }
    } // runq

    // //////////////// hyperparameters //////////////////

    /**
     * Updates alpha, alphax and beta by damped MAP estimation (averaging the old
     * value with DirichletEstimation's estimate while the first component is
     * below 2), then refreshes alphasum, alphaxsum and betasum to match.
     * Does nothing while {@code iter < 15} (burn-in).
     */
    public void estAlpha() {
        if (iter < 15) {
            return;
        }
        // Note: assuming non-informative gamma priors (1,0)
        // hyperparameter for theta: each nmx[m] row sums to the document length
        int[] nmxsum = new int[M];
        for (int m = 0; m < M; m++) {
            nmxsum[m] = w[m].length;
        }
        double[] xalpha = DirichletEstimation.estimateAlphaMap(nmx, nmxsum,
                alpha, 1., 0.);
        if (alpha[0] < 2.) {
            for (int t = 0; t < X; t++) {
                alpha[t] = (xalpha[t] + alpha[t]) / 2;
            }
        }
        alphasum = 0;
        for (int t = 0; t < X; t++) {
            alphasum += alpha[t];
        }
        // hyperparameter for thetax: map each row X * m + x of nmxy (M * X rows)
        // to its hyperparameter group x
        int[] mx2j = new int[M * X];
        for (int m = 0; m < M; m++) {
            for (int hx = 0; hx < X; hx++) {
                mx2j[X * m + hx] = hx;
            }
        }
        double[][] xalphax = DirichletEstimation.estimateAlphaMapSub(nmxy,
                nmxysum, mx2j, alphax, 1., 0.);
        if (alphax[0][0] < 2.) {
            for (int j = 0; j < X; j++) {
                for (int t = 0; t < Y; t++) {
                    alphax[j][t] = (xalphax[j][t] + alphax[j][t]) / 2;
                }
            }
        }
        for (int j = 0; j < X; j++) {
            double rowsum = 0;
            for (int t = 0; t < Y; t++) {
                rowsum += alphax[j][t];
            }
            alphaxsum[j] = rowsum;
        }
        // hyperparameter for phi
        double xbeta = DirichletEstimation.estimateAlphaMap(nkw, nkwsum, beta,
                1., 0.);
        if (beta < 2.) {
            beta = (beta + xbeta) / 2;
        }
        betasum = V * beta;
    } // estAlpha

    // //////////////// perplexity //////////////////

    /**
     * Computes the query-set perplexity exp(-loglik / Wq) from point estimates of
     * theta and thetax (query counts) and the fixed phi from initq(). Returns NaN
     * when the query corpus has no tokens.
     */
    public double ppx() {
        // theta estimates for the query documents
        double[][] thetaq = new double[Mq][X];
        for (int m = 0; m < Mq; m++) {
            int ntokens = 0;
            for (int hx = 0; hx < X; hx++) {
                ntokens += nmxq[m][hx];
            }
            for (int hx = 0; hx < X; hx++) {
                thetaq[m][hx] = (nmxq[m][hx] + alpha[hx]) / (ntokens + alphasum);
            }
        }
        // thetax estimates, rows indexed by X * m + x
        double[][] thetaxq = new double[Mq * X][Y];
        for (int m = 0; m < Mq; m++) {
            for (int hx = 0; hx < X; hx++) {
                int mxsel = X * m + hx;
                for (int hy = 0; hy < Y; hy++) {
                    thetaxq[mxsel][hy] = (nmxyq[mxsel][hy] + alphax[hx][hy])
                            / (nmxysumq[mxsel] + alphaxsum[hx]);
                }
            }
        }
        // marginalise (x, y) out of every query token's likelihood
        double loglik = 0;
        for (int m = 0; m < Mq; m++) {
            for (int n = 0; n < wq[m].length; n++) {
                int term = wq[m][n];
                double sum = 0;
                for (int hx = 0; hx < X; hx++) {
                    int mxsel = X * m + hx;
                    for (int hy = 0; hy < Y; hy++) {
                        sum += thetaq[m][hx] * thetaxq[mxsel][hy]
                                * phi[topicIndex(hx, hy)][term];
                    }
                }
                loglik += Math.log(sum);
            }
        }
        return Math.exp(-loglik / Wq);
    } // ppx

    // //////////////// monitor string //////////////////

    /** Describes the model structure, corpus sizes and current hyperparameters. */
    @Override
    public String toString() {
        return "HPAM2:\nm >> (theta[m] | alpha) >> x[m][n]\n\t"
                + "m, x[m][n] >> (thetax[m,x] | alphax[x]) >> y[m][n]\n\t"
                + "x[m][n], y[m][n] >> (phi[k] | beta) >> w[m][n]\n\t"
                + "\n\t"
                + "k: { \n\t"
                + "if (x==0) k = 0;\n\t"
                + "else if (y==0) k = 1 + x;\n\t"
                + "else k = 1 + X + y; \n\t"
                + "}. \n\t"
                + "\n"
                + String.format(
                        "Hpam2pGibbsSampler: P = %d\n"
                                + "M = %d Mq = %d W = %d Wq = %d \n"
                                + "X = %d Y = %d V = %d \n"
                                + "alpha[0] = %2.5f alphax[0][0] = %2.5f beta = %2.5f ",
                        P, M, Mq, W, Wq, X, Y, V, alpha[0], alphax[0][0], beta);
    }
} // Hpam2pGibbsSampler