/*
 * Hpam1p.java -- HPAM1 Gibbs sampling implementation
 * Copyright (C) 2009-2011 Gregor Heinrich, gregor :: arbylon : net
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.knowceans.topics.cgen;

// imports
import java.util.Random;

import org.knowceans.corpus.LabelNumCorpus;
import org.knowceans.util.CokusRandom;
import org.knowceans.util.DirichletEstimation;
import org.knowceans.util.ParallelFor;
import org.knowceans.util.StopWatch;

/**
 * Generated Gibbs sampler for the HPAM1 model. Hierarchical PAM model (HPAM1),
 * testing C1B and C2 structures as well as custom selectors.
 * <p>
 * Implementation using parallelised samplers.
 * <p>
 * Mixture network specification:
 * 
 * <pre>
 * m >> (theta[m] | alpha) >> x[m][n]
 * m, x[m][n] >> (thetax[m,x] | alphax[x]) >> y[m][n]
 * x[m][n], y[m][n] >> (zeta[x,y] | gamma) >> l[m][n]
 * x[m][n], y[m][n], l[m][n] >> (phi[k] | beta) >> w[m][n]
 * 
 * k: { 
 * if (l==0) k = 0;
 * else if (l==1) k = 1 + x;
 * else if (l==2) k = 1 + X + y; 
 * }. 
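 * 
 * For illustration only (using the X = 20, Y = 20 defaults from the main()
 * driver below): the level l selects, together with x and y, which of the
 * 1 + X + Y = 41 word distributions phi[k] emits the token, namely
 * 
 *    l = 0 (root)       : k = 0
 *    l = 1 (supertopic) : k = 1 + x       (k in 1..20)
 *    l = 2 (subtopic)   : k = 1 + X + y   (k in 21..40)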
 * 
 * elements:
 * 
 *    document m
 *    type: {ROOT|E3COUPLED}
 *    parents: (root)
 *    children: theta thetax
 *    range: M
 * ----
 *    doc-suptop (theta[m] | alpha)
 *    type: {SEQUENCE|C1ROOT}
 *    parents: m
 *    children: x
 *    components: theta, domain: M, range: X
 *    counts: nmx, sum: null
 *    index: m, selector: null
 *    hyperparams: alpha, dimension X, fixed: false 
 * ----
 *    top-subtop (thetax[m,x] | alphax[x])
 *    type: {SEQUENCE|C1BSEQADD}
 *    parents: m x
 *    children: y
 *    components: thetax, domain: M *  X, range: Y
 *    counts: nmxy, sum: nmxysum
 *    index: mx, selector: m,x
 *    hyperparams: alphax, dimension: X * Y, selector: x, fixed: false 
 * ----
 *    suptopic x[m][n]
 *    type: {HIDDEN|E3COUPLED}
 *    parents: theta
 *    children: thetax zeta phi
 *    range: X
 * ----
 *    subtopic y[m][n]
 *    type: {HIDDEN|E3COUPLED}
 *    parents: thetax
 *    children: zeta phi
 *    range: Y
 * ----
 *    toptop-level (zeta[x,y] | gamma)
 *    type: {TOPIC|C2MULTI}
 *    parents: x y
 *    children: l
 *    components: zeta, domain: X *  Y, range: L
 *    counts: nxyl, sum: nxylsum
 *    index: xy, selector: x,y
 *    hyperparams: gamma, dimension 1, fixed: false 
 * ----
 *    hiertop-word (phi[k] | beta)
 *    type: {TOPIC|C2MULTI}
 *    parents: x y l
 *    children: w
 *    components: phi, domain: 1 + X + Y, range: V
 *    counts: nkw, sum: nkwsum
 *    index: k, selector: 
 * if (l==0) k = 0;
 * else if (l==1) k = 1 + x;
 * else if (l==2) k = 1 + X + y;
 *    hyperparams: beta, dimension 1, fixed: false 
 * ----
 *    level l[m][n]
 *    type: {HIDDEN|QFIXED|E1SINGLE}
 *    parents: zeta
 *    children: phi
 *    range: L
 * ----
 *    word w[m][n]
 *    type: {VISIBLE|E1SINGLE}
 *    parents: phi
 *    children: (leaf)
 *    range: V
 * ----
 * sequences:
 * 
 * words [m][n]
 * parent: (root), children: []
 * edges: m x y l w
 * 
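 * component indexing (illustrative note, mirroring the index/selector
 * entries above):
 * 
 *    thetax[m,x] rows (and nmxy/nmxysum) are addressed by mx = X * m + x
 *    zeta[x,y] rows (and nxyl/nxylsum) are addressed by   xy = Y * x + y
 * 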
 * </pre>
 * 
 * @author Gregor Heinrich, gregor :: arbylon : net (via MixNetKernelGenerator)
 * @date generated on 11 Mar 2011
 */
// constructor
public class Hpam1pGibbsSampler {

    // //////////////// fields //////////////////

    // fields
    Random[] rand;
    int iter;
    int niter;
    // root edge: document
    int M;
    int Mq;
    // sequence node: doc-suptop
    int[][] nmx;
    int[][] nmxq;
    double[] alpha;
    double alphasum;
    // sequence node: top-subtop
    int[][] nmxy;
    int[][] nmxyq;
    int[] nmxysum;
    int[] nmxysumq;
    double[][] alphax;
    double[] alphaxsum;
    // hidden edge: suptopic
    int[][] x;
    int[][] xq;
    int X;
    // hidden edge: subtopic
    int[][] y;
    int[][] yq;
    int Y;
    // topic node: toptop-level
    int[][] nxyl;
    int[] nxylsum;
    double gamma;
    double gammasum;
    double[][] zeta;
    // topic node: hiertop-word
    int[][] nkw;
    int[] nkwsum;
    double beta;
    double betasum;
    double[][] phi;
    // hidden edge: level
    int[][] l;
    int[][] lq;
    int L;
    // visible edge: word
    int[][] w;
    int[][] wq;
    int V;
    // sequence: words
    int W;
    int Wq;
    // sampling weights
    double[][][][] pp;
    int P;

    // //////////////// main //////////////////

    // standard main routine
    public static void main(String[] args) {
        String filebase = "nips/nips";
        Random rand = new CokusRandom();
        // set up corpus
        LabelNumCorpus corpus = new LabelNumCorpus(filebase);
        corpus.split(10, 1, rand);
        LabelNumCorpus train = (LabelNumCorpus) corpus.getTrainCorpus();
        LabelNumCorpus test = (LabelNumCorpus) corpus.getTestCorpus();
        // TODO: adjust data source
        int[][] w = train.getDocWords(rand);
        int[][] wq = test.getDocWords(rand);
        int V = corpus.getNumTerms();
        // parameters
        int X = 20;
        int Y = 20;
        double alpha = 0.1;
        double alphax = 0.1;
        double gamma = 0.1;
        double beta = 0.1;
        int niter = 20, niterq = 10;
        // create sampler
        Hpam1pGibbsSampler gs = new Hpam1pGibbsSampler(alpha, alphax, X, Y,
                gamma, beta, w, wq, V, rand);
        gs.init();
        System.out.println(gs);
        // initial test
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        // run sampler
        StopWatch.start();
        gs.init();
        gs.run(niter);
        System.out.println(StopWatch.format(StopWatch.stop()));
        // final test
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        System.out.println(gs);
    } // main

    // //////////////// constructor //////////////////

    public Hpam1pGibbsSampler(double alpha, double alphax, int X, int Y,
            double gamma, double beta, int[][] w, int[][] wq, int V,
            Random rand) {
        // assign
        this.X = X;
        this.Y = Y;
        this.w = w;
        this.wq = wq;
        this.V = V;
        // constants: number of levels (needed before gammasum is computed)
        this.L = 3;
        this.alpha = new double[X];
        for (int i = 0; i < X; i++) {
            this.alpha[i] = alpha;
        }
        this.alphasum = X * alpha;
        this.alphax = new double[X][Y];
        this.alphaxsum = new double[X];
        for (int mxj = 0; mxj < X; mxj++) {
            for (int t = 0; t < Y; t++) {
                this.alphax[mxj][t] = alphax;
            } // for t
            this.alphaxsum[mxj] = Y * alphax;
        } // for mxj
        this.gamma = gamma;
        this.gammasum = L * gamma;
        this.beta = beta;
        this.betasum = V * beta;
        // count tokens
        M = w.length;
        W = 0;
        for (int m = 0; m < M; m++) {
            W += w[m].length;
        }
        Mq = wq.length;
        Wq = 0;
        for (int m = 0; m < Mq; m++) {
            Wq += wq[m].length;
        }
        P = Runtime.getRuntime().availableProcessors();
        this.rand = new CokusRandom[P];
        // all processors
        for (int p = 0; p < P; p++) {
            this.rand[p] = new CokusRandom(rand.nextInt());
        } // for p
        // allocate sampling weights
        pp = new double[P][X][Y][L];
    } // c'tor

    // //////////////// initialisation //////////////////

    // initialisation
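    // init() allocates the count statistics for the training corpus and draws
    // a uniformly random supertopic x, subtopic y and level l for every token,
    // so the Gibbs chain starts from a random state with consistent counts.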
    public void init() {
        // component selectors
        int mxsel = -1;
        int mxjsel = -1;
        int xysel = -1;
        int ksel = -1;
        // sequence node
        nmx = new int[M][X];
        // sequence node
        nmxy = new int[M * X][Y];
        nmxysum = new int[M * X];
        // hidden edge
        x = new int[M][];
        for (int m = 0; m < M; m++) {
            x[m] = new int[w[m].length];
        }
        // hidden edge
        y = new int[M][];
        for (int m = 0; m < M; m++) {
            y[m] = new int[w[m].length];
        }
        // topic node
        nxyl = new int[X * Y][L];
        nxylsum = new int[X * Y];
        // topic node
        nkw = new int[1 + X + Y][V];
        nkwsum = new int[1 + X + Y];
        // hidden edge
        l = new int[M][];
        for (int m = 0; m < M; m++) {
            l[m] = new int[w[m].length];
        }
        // initialise randomly
        // major loop, sequence [m][n]
        for (int m = 0; m < M; m++) {
            // minor loop, sequence [m][n]
            for (int n = 0; n < w[m].length; n++) {
                // sample edge values
                int hx = rand[0].nextInt(X);
                int hy = rand[0].nextInt(Y);
                int hl = rand[0].nextInt(L);
                // assign topics
                x[m][n] = hx;
                y[m][n] = hy;
                l[m][n] = hl;
                // increment counts
                nmx[m][hx]++;
                mxsel = X * m + hx;
                nmxy[mxsel][hy]++;
                nmxysum[mxsel]++;
                xysel = Y * hx + hy;
                synchronized (nxyl[xysel]) {
                    nxyl[xysel][hl]++;
                    nxylsum[xysel]++;
                } // sync
                if (hl == 0)
                    ksel = 0;
                else if (hl == 1)
                    ksel = 1 + hx;
                else if (hl == 2)
                    ksel = 1 + X + hy;
                synchronized (nkw[ksel]) {
                    nkw[ksel][w[m][n]]++;
                    nkwsum[ksel]++;
                } // sync
            } // for n
        } // for m
    } // init

    // //////////////// query initialisation //////////////////

    // initialisation
    public void initq() {
        // component selectors
        int mxsel = -1;
        int mxjsel = -1;
        int xysel = -1;
        int ksel = -1;
        // sequence node
        nmxq = new int[Mq][X];
        // sequence node
        nmxyq = new int[Mq * X][Y];
        nmxysumq = new int[Mq * X];
        // hidden edge
        xq = new int[Mq][];
        for (int m = 0; m < Mq; m++) {
            xq[m] = new int[wq[m].length];
        }
        // hidden edge
        yq = new int[Mq][];
        for (int m = 0; m < Mq; m++) {
            yq[m] = new int[wq[m].length];
        }
        // topic node
        // compute parameters
        zeta = new double[X * Y][L];
        for (int hx = 0; hx < X; hx++) {
            for (int hy = 0; hy < Y; hy++) {
                for (int t = 0; t < L; t++) {
                    xysel = Y * hx + hy;
                    zeta[xysel][t] = (nxyl[xysel][t] + gamma)
                            / (nxylsum[xysel] + gammasum);
                } // t
            } // h
        } // h
        // topic node
        // compute parameters
        phi = new double[1 + X + Y][V];
        for (int hx = 0; hx < X; hx++) {
            for (int hy = 0; hy < Y; hy++) {
                for (int hl = 0; hl < L; hl++) {
                    for (int t = 0; t < V; t++) {
                        if (hl == 0)
                            ksel = 0;
                        else if (hl == 1)
                            ksel = 1 + hx;
                        else if (hl == 2)
                            ksel = 1 + X + hy;
                        phi[ksel][t] = (nkw[ksel][t] + beta)
                                / (nkwsum[ksel] + betasum);
                    } // t
                } // h
            } // h
        } // h
        // hidden edge
        lq = new int[Mq][];
        for (int m = 0; m < Mq; m++) {
            lq[m] = new int[wq[m].length];
        }
        // initialise randomly
        // major loop, sequence [m][n]
        for (int m = 0; m < Mq; m++) {
            // minor loop, sequence [m][n]
            for (int n = 0; n < wq[m].length; n++) {
                // sample edge values
                int hx = rand[0].nextInt(X);
                int hy = rand[0].nextInt(Y);
                // assign topics
                xq[m][n] = hx;
                yq[m][n] = hy;
                // increment counts
                nmxq[m][hx]++;
                mxsel = X * m + hx;
                nmxyq[mxsel][hy]++;
                nmxysumq[mxsel]++;
            } // for n
        } // for m
    } // initq

    // //////////////// main kernel //////////////////

    // Gibbs kernel
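    // run() performs collapsed Gibbs sampling over the training data. Each
    // document is dispatched as one ParallelFor work item, so the per-document
    // counts nmx/nmxy are only touched by a single thread, while the shared
    // topic counts nxyl/nkw are updated inside synchronized blocks on the
    // respective component rows.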
    public void run(int niter) {
        // for each document
        // parallel foreach m < M
        ParallelFor parfor = new ParallelFor(P) {

            @Override
            public void process(int m, int thread) {
                // component selectors
                int mxsel = -1;
                int mxjsel = -1;
                int xysel = -1;
                int ksel = -1;
                // minor loop, sequence [m][n]
                for (int n = 0; n < w[m].length; n++) {
                    double psum;
                    double u;
                    // assign topics
                    int hx = x[m][n];
                    int hy = y[m][n];
                    int hl = l[m][n];
                    // decrement counts
                    nmx[m][hx]--;
                    mxsel = X * m + hx;
                    nmxy[mxsel][hy]--;
                    nmxysum[mxsel]--;
                    xysel = Y * hx + hy;
                    synchronized (nxyl[xysel]) {
                        nxyl[xysel][hl]--;
                        nxylsum[xysel]--;
                    } // sync
                    if (hl == 0)
                        ksel = 0;
                    else if (hl == 1)
                        ksel = 1 + hx;
                    else if (hl == 2)
                        ksel = 1 + X + hy;
                    synchronized (nkw[ksel]) {
                        nkw[ksel][w[m][n]]--;
                        nkwsum[ksel]--;
                    } // sync
                    // compute weights
                    /*
                     * &p(x_{m,n} \eq x, y_{m,n} \eq y, l_{m,n} \eq l\; |\;\vec
                     * x_{-m,n}, \vec y_{-m,n}, \vec l_{-m,n}, \vec w, \cdot)
                     * \notag\\ &\qquad\propto (n^{-mn}_{m,x} + \alpha_{x} )
                     * \cdot \frac{n^{-mn}_{mx,y} + \alpha^\text{x}_{x, y}
                     * }{\sum_{y} n^{-mn}_{mx,y} + \alpha^\text{x}_{x, y}} \cdot
                     * \frac{n^{-mn}_{xy,l} + \gamma }{\sum_{l} n^{-mn}_{xy,l} +
                     * \gamma} \cdot \frac{n^{-mn}_{k,w} + \beta }{\sum_{w}
                     * n^{-mn}_{k,w} + \beta}
                     */
                    psum = 0;
                    // hidden edge
                    for (hx = 0; hx < X; hx++) {
                        // hidden edge
                        for (hy = 0; hy < Y; hy++) {
                            // hidden edge
                            for (hl = 0; hl < L; hl++) {
                                mxsel = X * m + hx;
                                mxjsel = hx;
                                xysel = Y * hx + hy;
                                if (hl == 0)
                                    ksel = 0;
                                else if (hl == 1)
                                    ksel = 1 + hx;
                                else if (hl == 2)
                                    ksel = 1 + X + hy;
                                pp[thread][hx][hy][hl] = (nmx[m][hx] + alpha[hx])
                                        * (nmxy[mxsel][hy] + alphax[mxjsel][hy])
                                        / (nmxysum[mxsel] + alphaxsum[mxjsel])
                                        * (nxyl[xysel][hl] + gamma)
                                        / (nxylsum[xysel] + gammasum)
                                        * (nkw[ksel][w[m][n]] + beta)
                                        / (nkwsum[ksel] + betasum);
                                psum += pp[thread][hx][hy][hl];
                            } // for h
                        } // for h
                    } // for h
                    // sample topics
                    u = rand[thread].nextDouble() * psum;
                    psum = 0;
                    SAMPLED:
                    // each edge value
                    for (hx = 0; hx < X; hx++) {
                        // each edge value
                        for (hy = 0; hy < Y; hy++) {
                            // each edge value
                            for (hl = 0; hl < L; hl++) {
                                psum += pp[thread][hx][hy][hl];
                                if (u <= psum)
                                    break SAMPLED;
                            } // h
                        } // h
                    } // h
                    // assign topics
                    x[m][n] = hx;
                    y[m][n] = hy;
                    l[m][n] = hl;
                    // increment counts
                    nmx[m][hx]++;
                    mxsel = X * m + hx;
                    nmxy[mxsel][hy]++;
                    nmxysum[mxsel]++;
                    xysel = Y * hx + hy;
                    synchronized (nxyl[xysel]) {
                        nxyl[xysel][hl]++;
                        nxylsum[xysel]++;
                    } // sync
                    if (hl == 0)
                        ksel = 0;
                    else if (hl == 1)
                        ksel = 1 + hx;
                    else if (hl == 2)
                        ksel = 1 + X + hy;
                    synchronized (nkw[ksel]) {
                        nkw[ksel][w[m][n]]++;
                        nkwsum[ksel]++;
                    } // sync
                } // for n
            } // process
        }; // foreach m
        // Gibbs iterations
        for (int iter = 0; iter < niter; iter++) {
            System.out.println(iter);
            // expose the current iteration to estAlpha() (burn-in check)
            this.iter = iter;
            // parallel loop
            parfor.loop(M);
            // estimate hyperparameters
            estAlpha();
        } // for iter
        parfor.shutdown();
    } // for run

    // //////////////// query kernel //////////////////

    // Gibbs kernel
    public void runq(int niterq) {
        // for each document
        // parallel foreach m < Mq
        ParallelFor parfor = new ParallelFor(P) {

            @Override
            public void process(int m, int thread) {
                // component selectors
                int mxsel = -1;
                int mxjsel = -1;
                int xysel = -1;
                int ksel = -1;
                // minor loop, sequence [m][n]
                for (int n = 0; n < wq[m].length; n++) {
                    double psum;
                    double u;
                    // assign topics
                    int hx = xq[m][n];
                    int hy = yq[m][n];
                    int hl;
                    // decrement counts
                    nmxq[m][hx]--;
                    mxsel = X * m + hx;
                    nmxyq[mxsel][hy]--;
                    nmxysumq[mxsel]--;
                    // compute weights
                    /*
                     * &p(x_{m,n} \eq x, y_{m,n} \eq y\; |\;\vec x_{-m,n}, \vec
                     * y_{-m,n}, \vec w, \cdot) \notag\\ &\qquad\propto
                     * (n^{-mn}_{m,x} + \alpha_{x} ) \cdot \frac{n^{-mn}_{mx,y}
                     * + \alpha^\text{x}_{x, y} }{\sum_{y} n^{-mn}_{mx,y} +
                     * \alpha^\text{x}_{x, y}} \cdot \zeta_{xy,l} \cdot
                     * \phi_{k,w}
                     */
                    psum = 0;
                    // hidden edge
                    for (hx = 0; hx < X; hx++) {
                        // hidden edge
                        for (hy = 0; hy < Y; hy++) {
                            // hidden edge
                            for (hl = 0; hl < L; hl++) {
                                mxsel = X * m + hx;
                                mxjsel = hx;
                                xysel = Y * hx + hy;
                                if (hl == 0)
                                    ksel = 0;
                                else if (hl == 1)
                                    ksel = 1 + hx;
                                else if (hl == 2)
                                    ksel = 1 + X + hy;
                                pp[thread][hx][hy][hl] = (nmxq[m][hx] + alpha[hx])
                                        * (nmxyq[mxsel][hy] + alphax[mxjsel][hy])
                                        / (nmxysumq[mxsel] + alphaxsum[mxjsel])
                                        * zeta[xysel][hl] * phi[ksel][wq[m][n]];
                                psum += pp[thread][hx][hy][hl];
                            } // for h
                        } // for h
                    } // for h
                    // sample topics
                    u = rand[thread].nextDouble() * psum;
                    psum = 0;
                    SAMPLED:
                    // each edge value
                    for (hx = 0; hx < X; hx++) {
                        // each edge value
                        for (hy = 0; hy < Y; hy++) {
                            // each edge value
                            for (hl = 0; hl < L; hl++) {
                                psum += pp[thread][hx][hy][hl];
                                if (u <= psum)
                                    break SAMPLED;
                            } // h
                        } // h
                    } // h
                    // assign topics
                    xq[m][n] = hx;
                    yq[m][n] = hy;
                    // increment counts
                    nmxq[m][hx]++;
                    mxsel = X * m + hx;
                    nmxyq[mxsel][hy]++;
                    nmxysumq[mxsel]++;
                } // for n
            } // process
        }; // foreach m
        // Gibbs iterations
        for (int iter = 0; iter < niterq; iter++) {
            System.out.println(iter);
            // parallel loop
            parfor.loop(Mq);
        } // for iter
        parfor.shutdown();
    } // for runq

    // //////////////// hyperparameters //////////////////

    // update hyperparameters
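    // estAlpha() re-estimates the Dirichlet hyperparameters alpha, alphax,
    // gamma and beta by MAP estimation from the current count statistics
    // (DirichletEstimation, non-informative gamma priors). Estimation starts
    // only after a burn-in of 15 iterations, new estimates are averaged with
    // the previous values, and values >= 2 are left unchanged as a safeguard.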
    public void estAlpha() {
        if (iter < 15) {
            return;
        }
        // component selectors
        int mxsel = -1;
        int mxjsel = -1;
        int xysel = -1;
        int ksel = -1;
        // Note: assuming non-informative gamma priors (1,0)
        // hyperparameter for theta
        int[] nmxsum = new int[M];
        // all components
        for (int m = 0; m < M; m++) {
            nmxsum[m] = w[m].length;
        } // for m
        double[] xalpha = DirichletEstimation.estimateAlphaMap(nmx, nmxsum,
                alpha, 1., 0.);
        if (alpha[0] < 2.) {
            for (int t = 0; t < X; t++) {
                alpha[t] = (xalpha[t] + alpha[t]) / 2;
            } // for t
        } // < 2
        // hyperparameter for thetax
        // filter nmxy and nmxysum through the jsel index space.
        int[] mx2j = new int[M * X];
        // for parent values
        for (int m = 0; m < M; m++) {
            // for parent values
            for (int hx = 0; hx < X; hx++) {
                mxsel = X * m + hx;
                mxjsel = hx;
                mx2j[mxsel] = mxjsel;
            } // for h
        } // for h
        double[][] xalphax = DirichletEstimation.estimateAlphaMapSub(nmxy,
                nmxysum, mx2j, alphax, 1., 0.);
        if (alphax[0][0] < 2.) {
            // all component groups
            for (int j = 0; j < X; j++) {
                // all values
                for (int t = 0; t < Y; t++) {
                    alphax[j][t] = (xalphax[j][t] + alphax[j][t]) / 2;
                } // for t
            } // for j
        } // < 2
        // hyperparameter for zeta
        double xgamma = DirichletEstimation.estimateAlphaMap(nxyl, nxylsum,
                gamma, 1., 0.);
        if (gamma < 2.) {
            gamma = (gamma + xgamma) / 2;
        } // < 2
        // hyperparameter for phi
        double xbeta = DirichletEstimation.estimateAlphaMap(nkw, nkwsum, beta,
                1., 0.);
        if (beta < 2.) {
            beta = (beta + xbeta) / 2;
        } // < 2
    } // updateHyper

    // //////////////// perplexity //////////////////

    // calculate perplexity value
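    /*
     * Held-out perplexity as computed by ppx() below:
     * ppx = \exp\Big( -\frac{1}{W_q} \sum_{m,n} \log \sum_{x,y,l}
     *     \theta^q_{m,x} \, \theta^{xq}_{mx,y} \, \zeta_{xy,l} \,
     *     \phi_{k,w_{m,n}} \Big)
     * with k determined from (l, x, y) as in the kernels above.
     */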
    public double ppx() {
        double loglik = 0;
        // component selectors
        int mxsel = -1;
        int mxjsel = -1;
        int xysel = -1;
        int ksel = -1;
        // compute sequence node parameters
        double[][] thetaq = new double[Mq][X];
        // for parent values
        for (int m = 0; m < Mq; m++) {
            int nmxmq = 0;
            for (int x = 0; x < X; x++) {
                nmxmq += nmxq[m][x];
            } // for x
            for (int x = 0; x < X; x++) {
                thetaq[m][x] = (nmxq[m][x] + alpha[x]) / (nmxmq + alphasum);
            } // for x
        } // for parent h
        double[][] thetaxq = new double[Mq * X][Y];
        // for parent values
        for (int m = 0; m < Mq; m++) {
            // for parent values
            for (int hx = 0; hx < X; hx++) {
                for (int y = 0; y < Y; y++) {
                    mxsel = X * m + hx;
                    mxjsel = hx;
                    thetaxq[mxsel][y] = (nmxyq[mxsel][y] + alphax[mxjsel][y])
                            / (nmxysumq[mxsel] + alphaxsum[mxjsel]);
                } // for y
            } // for parent h
        } // for parent h
        // compute ppx
        // major loop, sequence [m][n]
        for (int m = 0; m < Mq; m++) {
            // minor loop, sequence [m][n]
            for (int n = 0; n < wq[m].length; n++) {
                double sum = 0;
                // hidden edge
                for (int hx = 0; hx < X; hx++) {
                    // hidden edge
                    for (int hy = 0; hy < Y; hy++) {
                        // hidden edge
                        for (int hl = 0; hl < L; hl++) {
                            mxsel = X * m + hx;
                            xysel = Y * hx + hy;
                            if (hl == 0)
                                ksel = 0;
                            else if (hl == 1)
                                ksel = 1 + hx;
                            else if (hl == 2)
                                ksel = 1 + X + hy;
                            sum += thetaq[m][hx] * thetaxq[mxsel][hy]
                                    * zeta[xysel][hl] * phi[ksel][wq[m][n]];
                        } // for h
                    } // for h
                } // for h
                loglik += Math.log(sum);
            } // for n
        } // for m
        return Math.exp(-loglik / Wq);
    } // ppx

    // //////////////// monitor string //////////////////

    // describe class and parameters
    @Override
    public String toString() {
        return "HPAM1:\nm >> (theta[m] | alpha) >> x[m][n]\n\t"
                + "m, x[m][n] >> (thetax[m,x] | alphax[x]) >> y[m][n]\n\t"
                + "x[m][n], y[m][n] >> (zeta[x,y] | gamma) >> l[m][n]\n\t"
                + "x[m][n], y[m][n], l[m][n] >> (phi[k] | beta) >> w[m][n]\n\t"
                + "\n\t"
                + "k: { \n\t"
                + "if (l==0) k = 0;\n\t"
                + "else if (l==1) k = 1 + x;\n\t"
                + "else if (l==2) k = 1 + X + y; \n\t"
                + "}. \n\t"
                + "\n"
                + String.format("Hpam1pGibbsSampler: P = %d\n"
                        + "M = %d Mq = %d W = %d Wq = %d \n"
                        + "X = %d Y = %d L = %d V = %d \n"
                        + "alpha[0] = %2.5f alphax[0][0] = %2.5f gamma = %2.5f beta = %2.5f ",
                        P, M, Mq, W, Wq, X, Y, L, V, alpha[0], alphax[0][0],
                        gamma, beta);
    }
} // Hpam1pGibbsSampler