/*
* Hpam1.java -- HPAM1 Gibbs sampling implementation
* Copyright (C) 2009-2011 Gregor Heinrich, gregor :: arbylon : net
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Model structure (HPAM1):
*
* m >> (theta[m] | alpha) >> x[m][n]
* m, x[m][n] >> (thetax[m,x] | alphax[x]) >> y[m][n]
* x[m][n], y[m][n] >> (zeta[x,y] | gamma) >> l[m][n]
* x[m][n], y[m][n], l[m][n] >> (phi[k] | beta) >> w[m][n]
*
* k: {
* if (l==0) k = 0;
* else if (l==1) k = 1 + x;
* else if (l==2) k = 1 + X + y;
* }.
*
* elements:
*
* document m
* type: {ROOT|E3COUPLED}
* parents: (root)
* children: theta thetax
* range: M
* ----
* doc-suptop (theta[m] | alpha)
* type: {SEQUENCE|C1ROOT}
* parents: m
* children: x
* components: theta, domain: M, range: X
* counts: nmx, sum: null
* index: m, selector: null
* hyperparams: alpha, dimension X, fixed: false
* ----
* top-subtop (thetax[m,x] | alphax[x])
* type: {SEQUENCE|C1BSEQADD}
* parents: m x
* children: y
* components: thetax, domain: M * X, range: Y
* counts: nmxy, sum: nmxysum
* index: mx, selector: m,x
* hyperparams: alphax, dimension: X * Y, selector: x, fixed: false
* ----
* suptopic x[m][n]
* type: {HIDDEN|E3COUPLED}
* parents: theta
* children: thetax zeta phi
* range: X
* ----
* subtopic y[m][n]
* type: {HIDDEN|E3COUPLED}
* parents: thetax
* children: zeta phi
* range: Y
* ----
* toptop-level (zeta[x,y] | gamma)
* type: {TOPIC|C2MULTI}
* parents: x y
* children: l
* components: zeta, domain: X * Y, range: L
* counts: nxyl, sum: nxylsum
* index: xy, selector: x,y
* hyperparams: gamma, dimension 1, fixed: false
* ----
* hiertop-word (phi[k] | beta)
* type: {TOPIC|C2MULTI}
* parents: x y l
* children: w
* components: phi, domain: 1 + X + Y, range: V
* counts: nkw, sum: nkwsum
* index: k, selector:
* if (l==0) k = 0;
* else if (l==1) k = 1 + x;
* else if (l==2) k = 1 + X + y;
* hyperparams: beta, dimension 1, fixed: false
* ----
* level l[m][n]
* type: {HIDDEN|QFIXED|E1SINGLE}
* parents: zeta
* children: phi
* range: L
* ----
* word w[m][n]
* type: {VISIBLE|E1SINGLE}
* parents: phi
* children: (leaf)
* range: V
* ----
* sequences:
*
* words [m][n]
* parent: (root), children: []
* edges: m x y l w
*
*
* @author Gregor Heinrich, gregor :: arbylon : net (via MixNetKernelGenerator)
* @date generated on 11 Mar 2011
*/
/**
 * Gibbs sampler for the HPAM1 topic model.
 * <p>
 * Each token {@code w[m][n]} of document {@code m} carries three hidden
 * variables: a super-topic {@code x[m][n]} in {@code [0, X)}, a sub-topic
 * {@code y[m][n]} in {@code [0, Y)} and a level {@code l[m][n]} in
 * {@code [0, L)} with {@code L = 3}. The level selects the word component
 * {@code k} that emits the token:
 *
 * <pre>
 * l == 0  ->  k = 0
 * l == 1  ->  k = 1 + x
 * l == 2  ->  k = 1 + X + y
 * </pre>
 *
 * Training ({@link #init()}, {@link #run(int)}) resamples x, y and l jointly
 * from count statistics. The query kernel ({@link #initq()},
 * {@link #runq(int)}) holds {@code zeta} and {@code phi} fixed at estimates
 * computed from the training counts and only tracks x and y of the test
 * documents; {@link #ppx()} then reports test-set perplexity.
 * <p>
 * Fields with a {@code q} suffix are the query (test-corpus) counterparts of
 * the training fields of the same name.
 */
public class Hpam1GibbsSampler {

    // //////////////// fields //////////////////

    /** Random number source for initialisation and sampling. */
    Random rand;
    /** Index of the current training iteration; read by estAlpha() for its burn-in. */
    int iter;
    /** Number of iterations requested by the most recent call to run(int). */
    int niter;

    // root edge: document
    /** Number of training documents. */
    int M;
    /** Number of query documents. */
    int Mq;

    // sequence node: doc-suptop
    /** nmx[m][x]: number of tokens in training document m with super-topic x. */
    int[][] nmx;
    int[][] nmxq;
    /** alpha[x]: Dirichlet hyperparameter of theta[m] (super-topic proportions). */
    double[] alpha;
    /** Sum over x of alpha[x]; kept in sync with alpha. */
    double alphasum;

    // sequence node: top-subtop
    /** nmxy[X * m + x][y]: tokens in document m with super-topic x and sub-topic y. */
    int[][] nmxy;
    int[][] nmxyq;
    /** nmxysum[X * m + x]: row sums of nmxy. */
    int[] nmxysum;
    int[] nmxysumq;
    /** alphax[x][y]: Dirichlet hyperparameter of thetax[m,x], shared by all documents. */
    double[][] alphax;
    /** alphaxsum[x]: row sums of alphax; kept in sync with alphax. */
    double[] alphaxsum;

    // hidden edge: suptopic
    /** x[m][n]: super-topic of token n in training document m. */
    int[][] x;
    int[][] xq;
    /** Number of super-topics. */
    int X;

    // hidden edge: subtopic
    /** y[m][n]: sub-topic of token n in training document m. */
    int[][] y;
    int[][] yq;
    /** Number of sub-topics. */
    int Y;

    // topic node: toptop-level
    /** nxyl[Y * x + y][l]: tokens with super-topic x, sub-topic y and level l. */
    int[][] nxyl;
    int[] nxylsum;
    double gamma;
    /** L * gamma; kept in sync with gamma. */
    double gammasum;
    /** zeta[Y * x + y][l]: level distributions, estimated in initq(). */
    double[][] zeta;

    // topic node: hiertop-word
    /** nkw[k][w]: tokens of term w emitted by word component k (1 + X + Y components). */
    int[][] nkw;
    int[] nkwsum;
    double beta;
    /** V * beta; kept in sync with beta. */
    double betasum;
    /** phi[k][w]: word distributions, estimated in initq(). */
    double[][] phi;

    // hidden edge: level
    /** l[m][n]: level of token n in training document m. */
    int[][] l;
    /** Allocated by initq(); query levels are not tracked by runq(). */
    int[][] lq;
    /** Number of levels (fixed to 3). */
    int L;

    // visible edge: word
    /** w[m][n]: term index of token n in training document m. */
    int[][] w;
    int[][] wq;
    /** Vocabulary size. */
    int V;

    // sequence: words
    /** Total number of training tokens. */
    int W;
    /** Total number of query tokens. */
    int Wq;

    /** Scratch buffer of unnormalised sampling weights pp[x][y][l]. */
    double[][][] pp;

    // //////////////// main //////////////////

    /**
     * Demo driver: loads the "nips/nips" corpus, splits it into training and
     * test parts, and prints the test perplexity before and after training.
     */
    public static void main(String[] args) {
        String filebase = "nips/nips";
        Random rand = new CokusRandom();
        // set up corpus
        LabelNumCorpus corpus = new LabelNumCorpus(filebase);
        corpus.split(10, 1, rand);
        LabelNumCorpus train = (LabelNumCorpus) corpus.getTrainCorpus();
        LabelNumCorpus test = (LabelNumCorpus) corpus.getTestCorpus();
        // TODO: adjust data source
        int[][] w = train.getDocWords(rand);
        int[][] wq = test.getDocWords(rand);
        int V = corpus.getNumTerms();
        // parameters
        int X = 20;
        int Y = 20;
        double alpha = 0.1;
        double alphax = 0.1;
        double gamma = 0.1;
        double beta = 0.1;
        int niter = 20, niterq = 10;
        // create sampler
        Hpam1GibbsSampler gs = new Hpam1GibbsSampler(alpha, alphax, X, Y,
                gamma, beta, w, wq, V, rand);
        gs.init();
        System.out.println(gs);
        // initial test
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        // run sampler
        StopWatch.start();
        gs.init();
        gs.run(niter);
        System.out.println(StopWatch.format(StopWatch.stop()));
        // final test
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        System.out.println(gs);
    } // main

    // //////////////// constructor //////////////////

    /**
     * Creates a sampler with symmetric initial hyperparameters.
     *
     * @param alpha  initial value of every alpha[x]
     * @param alphax initial value of every alphax[x][y]
     * @param X      number of super-topics
     * @param Y      number of sub-topics
     * @param gamma  initial level hyperparameter
     * @param beta   initial word hyperparameter
     * @param w      training documents as term indices, w[m][n]
     * @param wq     query documents as term indices, wq[m][n]
     * @param V      vocabulary size
     * @param rand   random number source
     */
    public Hpam1GibbsSampler(double alpha, double alphax, int X, int Y,
            double gamma, double beta, int[][] w, int[][] wq, int V, Random rand) {
        // constants first: gammasum and pp below depend on L
        this.L = 3;
        this.X = X;
        this.Y = Y;
        this.w = w;
        this.wq = wq;
        this.V = V;
        this.alpha = new double[X];
        for (int i = 0; i < X; i++) {
            this.alpha[i] = alpha;
        }
        this.alphasum = X * alpha;
        this.alphax = new double[X][Y];
        this.alphaxsum = new double[X];
        for (int j = 0; j < X; j++) {
            for (int t = 0; t < Y; t++) {
                this.alphax[j][t] = alphax;
            }
            this.alphaxsum[j] = Y * alphax;
        }
        this.gamma = gamma;
        this.gammasum = L * gamma;
        this.beta = beta;
        this.betasum = V * beta;
        // count tokens
        M = w.length;
        W = 0;
        for (int m = 0; m < M; m++) {
            W += w[m].length;
        }
        Mq = wq.length;
        Wq = 0;
        for (int m = 0; m < Mq; m++) {
            Wq += wq[m].length;
        }
        this.rand = rand;
        // allocate sampling weights
        pp = new double[X][Y][L];
    } // c'tor

    // //////////////// helpers //////////////////

    /** Returns the word-component index k selected by level hl for topics (hx, hy). */
    private int kSel(int hl, int hx, int hy) {
        if (hl == 0) {
            return 0;
        }
        if (hl == 1) {
            return 1 + hx;
        }
        return 1 + X + hy;
    }

    /** Adds delta to every training count touched by one token's (x, y, l) assignment. */
    private void addTrainToken(int m, int word, int hx, int hy, int hl, int delta) {
        nmx[m][hx] += delta;
        int mxsel = X * m + hx;
        nmxy[mxsel][hy] += delta;
        nmxysum[mxsel] += delta;
        int xysel = Y * hx + hy;
        nxyl[xysel][hl] += delta;
        nxylsum[xysel] += delta;
        int ksel = kSel(hl, hx, hy);
        nkw[ksel][word] += delta;
        nkwsum[ksel] += delta;
    }

    /** Adds delta to every query count touched by one token's (x, y) assignment. */
    private void addQueryToken(int m, int hx, int hy, int delta) {
        nmxq[m][hx] += delta;
        int mxsel = X * m + hx;
        nmxyq[mxsel][hy] += delta;
        nmxysumq[mxsel] += delta;
    }

    /**
     * Draws one (x, y, l) triple with probability proportional to pp[x][y][l].
     *
     * @param psum sum of all weights currently in pp, accumulated in x, y, l order
     * @return the flattened index (x * Y + y) * L + l of the drawn triple
     * @throws IllegalStateException if no triple is selected, which should be
     *         unreachable for finite, non-negative weights
     */
    private int sampleTriple(double psum) {
        double u = rand.nextDouble() * psum;
        double cum = 0;
        for (int hx = 0; hx < X; hx++) {
            for (int hy = 0; hy < Y; hy++) {
                for (int hl = 0; hl < L; hl++) {
                    cum += pp[hx][hy][hl];
                    if (u <= cum) {
                        return (hx * Y + hy) * L + hl;
                    }
                }
            }
        }
        throw new IllegalStateException("no topic sampled; weight sum = " + psum);
    }

    // //////////////// initialisation //////////////////

    /** Allocates the training counts and assigns every token a uniformly random (x, y, l). */
    public void init() {
        nmx = new int[M][X];
        nmxy = new int[M * X][Y];
        nmxysum = new int[M * X];
        nxyl = new int[X * Y][L];
        nxylsum = new int[X * Y];
        nkw = new int[1 + X + Y][V];
        nkwsum = new int[1 + X + Y];
        x = new int[M][];
        y = new int[M][];
        l = new int[M][];
        for (int m = 0; m < M; m++) {
            x[m] = new int[w[m].length];
            y[m] = new int[w[m].length];
            l[m] = new int[w[m].length];
        }
        // initialise randomly
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < w[m].length; n++) {
                int hx = rand.nextInt(X);
                int hy = rand.nextInt(Y);
                int hl = rand.nextInt(L);
                x[m][n] = hx;
                y[m][n] = hy;
                l[m][n] = hl;
                addTrainToken(m, w[m][n], hx, hy, hl, 1);
            }
        }
    } // init

    // //////////////// query initialisation //////////////////

    /**
     * Estimates zeta and phi from the current training counts, allocates the
     * query counts and assigns every query token a uniformly random (x, y).
     */
    public void initq() {
        nmxq = new int[Mq][X];
        nmxyq = new int[Mq * X][Y];
        nmxysumq = new int[Mq * X];
        xq = new int[Mq][];
        yq = new int[Mq][];
        for (int m = 0; m < Mq; m++) {
            xq[m] = new int[wq[m].length];
            yq[m] = new int[wq[m].length];
        }
        // level distributions, one row per (x, y) cell Y * x + y
        zeta = new double[X * Y][L];
        for (int xy = 0; xy < X * Y; xy++) {
            for (int t = 0; t < L; t++) {
                zeta[xy][t] = (nxyl[xy][t] + gamma) / (nxylsum[xy] + gammasum);
            }
        }
        // word distributions, one row per component k (each computed once)
        phi = new double[1 + X + Y][V];
        for (int k = 0; k < phi.length; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkw[k][t] + beta) / (nkwsum[k] + betasum);
            }
        }
        lq = new int[Mq][];
        for (int m = 0; m < Mq; m++) {
            lq[m] = new int[wq[m].length];
        }
        // initialise randomly
        for (int m = 0; m < Mq; m++) {
            for (int n = 0; n < wq[m].length; n++) {
                int hx = rand.nextInt(X);
                int hy = rand.nextInt(Y);
                xq[m][n] = hx;
                yq[m][n] = hy;
                addQueryToken(m, hx, hy, 1);
            }
        }
    } // initq

    // //////////////// main kernel //////////////////

    /**
     * Runs niter training sweeps. Each sweep resamples (x, y, l) of every
     * training token, then calls estAlpha(). The iteration index is printed to
     * stdout and stored in the {@code iter} field so that estAlpha() can apply
     * its burn-in.
     *
     * <pre>
     * p(x, y, l | rest) ∝ (n_{m,x} + alpha_x)
     *     * (n_{mx,y} + alphax_{x,y}) / (n_{mx} + alphaxsum_x)
     *     * (n_{xy,l} + gamma) / (n_{xy} + L gamma)
     *     * (n_{k,w} + beta) / (n_k + V beta)      (counts exclude the current token)
     * </pre>
     */
    public void run(int niter) {
        this.niter = niter;
        for (int it = 0; it < niter; it++) {
            iter = it;
            System.out.println(it);
            for (int m = 0; m < M; m++) {
                for (int n = 0; n < w[m].length; n++) {
                    int word = w[m][n];
                    // remove the token's current assignment from the counts
                    addTrainToken(m, word, x[m][n], y[m][n], l[m][n], -1);
                    // unnormalised full conditional for every (x, y, l)
                    double psum = 0;
                    for (int hx = 0; hx < X; hx++) {
                        int mxsel = X * m + hx;
                        for (int hy = 0; hy < Y; hy++) {
                            // factor shared by all levels of this (x, y) cell
                            double pxy = (nmx[m][hx] + alpha[hx])
                                    * (nmxy[mxsel][hy] + alphax[hx][hy])
                                    / (nmxysum[mxsel] + alphaxsum[hx]);
                            int xysel = Y * hx + hy;
                            for (int hl = 0; hl < L; hl++) {
                                int ksel = kSel(hl, hx, hy);
                                pp[hx][hy][hl] = pxy
                                        * (nxyl[xysel][hl] + gamma)
                                        / (nxylsum[xysel] + gammasum)
                                        * (nkw[ksel][word] + beta)
                                        / (nkwsum[ksel] + betasum);
                                psum += pp[hx][hy][hl];
                            }
                        }
                    }
                    int idx = sampleTriple(psum);
                    int sx = idx / (Y * L);
                    int sy = (idx / L) % Y;
                    int sl = idx % L;
                    x[m][n] = sx;
                    y[m][n] = sy;
                    l[m][n] = sl;
                    addTrainToken(m, word, sx, sy, sl, 1);
                }
            }
            // estimate hyperparameters
            estAlpha();
        }
    } // run

    // //////////////// query kernel //////////////////

    /**
     * Runs niterq sweeps over the query documents, resampling (x, y) of every
     * token with zeta and phi held fixed. Levels are drawn jointly but not
     * stored. The iteration index is printed to stdout.
     */
    public void runq(int niterq) {
        for (int it = 0; it < niterq; it++) {
            System.out.println(it);
            for (int m = 0; m < Mq; m++) {
                for (int n = 0; n < wq[m].length; n++) {
                    int word = wq[m][n];
                    addQueryToken(m, xq[m][n], yq[m][n], -1);
                    double psum = 0;
                    for (int hx = 0; hx < X; hx++) {
                        int mxsel = X * m + hx;
                        for (int hy = 0; hy < Y; hy++) {
                            double pxy = (nmxq[m][hx] + alpha[hx])
                                    * (nmxyq[mxsel][hy] + alphax[hx][hy])
                                    / (nmxysumq[mxsel] + alphaxsum[hx]);
                            int xysel = Y * hx + hy;
                            for (int hl = 0; hl < L; hl++) {
                                pp[hx][hy][hl] = pxy * zeta[xysel][hl]
                                        * phi[kSel(hl, hx, hy)][word];
                                psum += pp[hx][hy][hl];
                            }
                        }
                    }
                    int idx = sampleTriple(psum);
                    int sx = idx / (Y * L);
                    int sy = (idx / L) % Y;
                    xq[m][n] = sx;
                    yq[m][n] = sy;
                    addQueryToken(m, sx, sy, 1);
                }
            }
        }
    } // runq

    // //////////////// hyperparameters //////////////////

    /**
     * Re-estimates alpha, alphax, gamma and beta from the training counts via
     * DirichletEstimation (prior arguments 1 and 0 are passed, noted as
     * non-informative gamma priors). Each hyperparameter moves halfway towards
     * its estimate, but only while its current (first) value is below 2. The
     * cached sums are refreshed after every update. The method does nothing
     * while {@code iter < 15}.
     */
    public void estAlpha() {
        if (iter < 15) {
            return;
        }
        // hyperparameter for theta: per-document totals are the token counts
        int[] nmxsum = new int[M];
        for (int m = 0; m < M; m++) {
            nmxsum[m] = w[m].length;
        }
        double[] xalpha = DirichletEstimation.estimateAlphaMap(nmx, nmxsum,
                alpha, 1., 0.);
        if (alpha[0] < 2.) {
            double sum = 0;
            for (int t = 0; t < X; t++) {
                alpha[t] = (xalpha[t] + alpha[t]) / 2;
                sum += alpha[t];
            }
            alphasum = sum;
        }
        // hyperparameter for thetax: map each training row X * m + x of nmxy
        // to its hyperparameter group x (must cover all M * X rows)
        int[] mx2j = new int[M * X];
        for (int m = 0; m < M; m++) {
            for (int hx = 0; hx < X; hx++) {
                mx2j[X * m + hx] = hx;
            }
        }
        double[][] xalphax = DirichletEstimation.estimateAlphaMapSub(nmxy,
                nmxysum, mx2j, alphax, 1., 0.);
        if (alphax[0][0] < 2.) {
            for (int j = 0; j < X; j++) {
                double sum = 0;
                for (int t = 0; t < Y; t++) {
                    alphax[j][t] = (xalphax[j][t] + alphax[j][t]) / 2;
                    sum += alphax[j][t];
                }
                alphaxsum[j] = sum;
            }
        }
        // hyperparameter for zeta
        double xgamma = DirichletEstimation.estimateAlphaMap(nxyl, nxylsum,
                gamma, 1., 0.);
        if (gamma < 2.) {
            gamma = (gamma + xgamma) / 2;
            gammasum = L * gamma;
        }
        // hyperparameter for phi
        double xbeta = DirichletEstimation.estimateAlphaMap(nkw, nkwsum, beta,
                1., 0.);
        if (beta < 2.) {
            beta = (beta + xbeta) / 2;
            betasum = V * beta;
        }
    } // estAlpha

    // //////////////// perplexity //////////////////

    /**
     * Computes the perplexity of the query documents:
     * exp(-sum log p(w) / Wq). Each p(w) is marginalised over (x, y, l) using
     * theta and thetax estimated from the query counts and the zeta and phi
     * estimates from initq().
     */
    public double ppx() {
        // per-document super-topic proportions
        double[][] thetaq = new double[Mq][X];
        for (int m = 0; m < Mq; m++) {
            int nmxmq = 0;
            for (int hx = 0; hx < X; hx++) {
                nmxmq += nmxq[m][hx];
            }
            for (int hx = 0; hx < X; hx++) {
                thetaq[m][hx] = (nmxq[m][hx] + alpha[hx]) / (nmxmq + alphasum);
            }
        }
        // per-(document, super-topic) sub-topic proportions
        double[][] thetaxq = new double[Mq * X][Y];
        for (int m = 0; m < Mq; m++) {
            for (int hx = 0; hx < X; hx++) {
                int mxsel = X * m + hx;
                for (int hy = 0; hy < Y; hy++) {
                    thetaxq[mxsel][hy] = (nmxyq[mxsel][hy] + alphax[hx][hy])
                            / (nmxysumq[mxsel] + alphaxsum[hx]);
                }
            }
        }
        double loglik = 0;
        for (int m = 0; m < Mq; m++) {
            for (int n = 0; n < wq[m].length; n++) {
                int word = wq[m][n];
                double sum = 0;
                for (int hx = 0; hx < X; hx++) {
                    int mxsel = X * m + hx;
                    for (int hy = 0; hy < Y; hy++) {
                        int xysel = Y * hx + hy;
                        for (int hl = 0; hl < L; hl++) {
                            sum += thetaq[m][hx] * thetaxq[mxsel][hy]
                                    * zeta[xysel][hl] * phi[kSel(hl, hx, hy)][word];
                        }
                    }
                }
                loglik += Math.log(sum);
            }
        }
        return Math.exp(-loglik / Wq);
    } // ppx

    // //////////////// monitor string //////////////////

    /** Describes the model structure and the current sizes and hyperparameters. */
    @Override
    public String toString() {
        return "HPAM1:\nm >> (theta[m] | alpha) >> x[m][n]\n\t"
                + "m, x[m][n] >> (thetax[m,x] | alphax[x]) >> y[m][n]\n\t"
                + "x[m][n], y[m][n] >> (zeta[x,y] | gamma) >> l[m][n]\n\t"
                + "x[m][n], y[m][n], l[m][n] >> (phi[k] | beta) >> w[m][n]\n\t"
                + "\n\t"
                + "k: { \n\t"
                + "if (l==0) k = 0;\n\t"
                + "else if (l==1) k = 1 + x;\n\t"
                + "else if (l==2) k = 1 + X + y; \n\t"
                + "}. \n\t"
                + "\n"
                + String.format(
                        "Hpam1GibbsSampler: \n"
                                + "M = %d Mq = %d W = %d Wq = %d \n"
                                + "X = %d Y = %d L = %d V = %d \n"
                                + "alpha[0] = %2.5f alphax[0][0] = %2.5f gamma = %2.5f beta = %2.5f ",
                        M, Mq, W, Wq, X, Y, L, V, alpha[0], alphax[0][0],
                        gamma, beta);
    }
} // Hpam1GibbsSampler