/*
 * Decompiled with CFR 0.152.
 */
package org.knowceans.topics.simple;

import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.Random;
import org.knowceans.corpus.NumCorpus;
import org.knowceans.corpus.VisCorpus;
import org.knowceans.topics.simple.ISimpleGibbs;
import org.knowceans.topics.simple.ISimplePpx;
import org.knowceans.topics.simple.ISimpleQueryGibbs;
import org.knowceans.topics.simple.LdaTopics;
import org.knowceans.topics.simple.TopicMatrixPanel;
import org.knowceans.util.ArrayIo;
import org.knowceans.util.CokusRandom;
import org.knowceans.util.IndexQuickSort;
import org.knowceans.util.StopWatch;
import org.knowceans.util.Vectors;

/**
 * Collapsed Gibbs sampler for latent Dirichlet allocation (LDA) with symmetric
 * Dirichlet hyperparameters {@code alpha} (document-topic) and {@code beta}
 * (topic-term).
 *
 * <p>Training documents are sampled with {@link #init()} / {@link #run(int)}.
 * Held-out ("query") documents are folded in with {@link #initq()} /
 * {@link #runq(int)}. Folding in keeps the topic-term estimate {@code phi}
 * fixed at its value when {@code initq()} was called. {@link #ppx()} then
 * reports the query-set perplexity.
 */
public class LdaGibbs implements ISimpleGibbs, ISimpleQueryGibbs, ISimplePpx {
    /** Optional visualisation; when non-null, {@code run} pushes {@code nkt} to it after every sweep. */
    private static TopicMatrixPanel vis;
    /** Training documents: {@code w[m][n]} is the term index of token n in document m. */
    private int[][] w;
    /** Query (held-out) documents, same layout as {@code w}. */
    private int[][] wq;
    /** Training document-topic counts {@code [M][K]}. */
    private int[][] nmk;
    /** Query document-topic counts {@code [Mq][K]}. */
    private int[][] nmkq;
    /** Topic-term counts {@code [K][V]} over the training set. */
    private int[][] nkt;
    /** Number of training tokens assigned to each topic {@code [K]}. */
    private int[] nk;
    /** Smoothed topic-term estimate {@code [K][V]}, derived from {@code nkt}/{@code nk}. */
    private double[][] phi;
    /** Topic assignment of every training token. */
    private int[][] z;
    /** Topic assignment of every query token. */
    private int[][] zq;
    private double alpha;
    private double beta;
    private Random rand;
    /** Sweep counter of the most recent {@code run} call. */
    private int iter;
    /** Number of topics. */
    private int K;
    /** Number of training documents. */
    private int M;
    /** Number of query documents. */
    private int Mq;
    /** Total number of query tokens (set by {@code initq}); used as the perplexity normaliser. */
    private int Wq;
    /** Vocabulary size (term indices are expected in {@code [0, V)}). */
    private int V;

    /**
     * Demo driver. Loads the {@code nips/nips.corpus} corpus and splits it into
     * train/test parts. It samples the training part, then prints the test
     * perplexity before and after training.
     *
     * <p>The result files are written only when no visualisation panel is
     * active, i.e. only when {@code display} is false.
     */
    public static void main(String[] args) {
        int niter = 500;
        int niterq = 10;
        int K = 50;
        String filebase = "nips/nips";
        boolean usefile = true;
        boolean display = true;
        CokusRandom rand = new CokusRandom(56567651L);
        NumCorpus corpus;
        if (usefile) {
            corpus = new NumCorpus(filebase + ".corpus");
            if (display) {
                vis = new TopicMatrixPanel(900, 400, (int) Math.sqrt(corpus.getNumTerms()), 1);
            }
        } else {
            corpus = VisCorpus.generateLdaCorpus(K, 1000, 200);
            if (display) {
                vis = new TopicMatrixPanel(900, 400, K, 300 / K);
            }
        }
        // NOTE(review): split(10, 2, rand) presumably selects fold 2 of 10 as the test set -- confirm in NumCorpus.
        corpus.split(10, 2, rand);
        NumCorpus train = (NumCorpus) corpus.getTrainCorpus();
        NumCorpus test = (NumCorpus) corpus.getTestCorpus();
        int[][] w = train.getDocWords(rand);
        int[][] wq = test.getDocWords(rand);
        int V = corpus.getNumTerms();
        double alpha = 0.1;
        double beta = 0.1;
        LdaGibbs gs = new LdaGibbs(w, wq, K, V, alpha, beta, rand);
        gs.init();
        System.out.println(gs);
        // baseline perplexity with randomly initialised topics
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        StopWatch.start();
        gs.run(niter);
        System.out.println(StopWatch.format(StopWatch.stop()));
        // perplexity after training
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        if (vis == null) {
            try {
                PrintStream bw = new PrintStream(filebase + ".lda.result");
                try {
                    gs.print(bw, filebase, corpus.getOrigDocIds()[0], train.getNumWords());
                } finally {
                    // close even if printing/saving fails, so the file handle is not leaked
                    bw.close();
                }
                System.out.println("done");
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Writes per-document and per-topic summaries to {@code out}. It also saves
     * theta and phi as binary matrices next to {@code filebase}.
     *
     * @param out      destination of the textual report
     * @param filebase path prefix used by {@link LdaTopics} and for the saved matrices
     * @param docids   original document ids, indexed like the training documents
     * @param W        total number of training tokens; used to normalise topic sizes
     */
    private void print(PrintStream out, String filebase, int[] docids, int W) {
        computePhi();
        double[][] theta = computeTheta();
        LdaTopics lt = new LdaTopics(filebase, theta, this.phi);
        for (int m = 0; m < M; m++) {
            out.println(lt.printDocument(m, docids[m], 10, false, true));
        }
        int[] kk = IndexQuickSort.sort(this.nk);
        IndexQuickSort.reverse(kk);
        for (int rank = 0; rank < K; rank++) {
            int k = kk[rank];
            // topic share relative to a uniform 1/K share (1.0 == average topic size)
            double relSize = (double) nk[k] / (double) W * (double) K;
            out.println(String.format("%d (%2.5f / %d): %s", rank, relSize, K, lt.printTopic(k, 20)));
        }
        ArrayIo.saveBinaryMatrix(filebase + ".lda.theta.zip", theta);
        ArrayIo.saveBinaryMatrix(filebase + ".lda.phi.zip", this.phi);
    }

    /**
     * Creates a sampler; no state is initialised until {@link #init()}.
     *
     * @param w     training documents as term-index arrays
     * @param wq    query documents as term-index arrays
     * @param K     number of topics
     * @param V     vocabulary size
     * @param alpha symmetric document-topic hyperparameter
     * @param beta  symmetric topic-term hyperparameter
     * @param rand  random source used for initialisation and sampling
     */
    public LdaGibbs(int[][] w, int[][] wq, int K, int V, double alpha, double beta, Random rand) {
        this.w = w;
        this.wq = wq;
        this.K = K;
        this.M = w.length;
        this.Mq = wq.length;
        this.V = V;
        this.alpha = alpha;
        this.beta = beta;
        this.rand = rand;
    }

    /** Assigns every training token a uniformly random topic and builds the count tables. */
    @Override
    public void init() {
        nmk = new int[M][K];
        nkt = new int[K][V];
        nk = new int[K];
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            int[] doc = w[m];
            z[m] = new int[doc.length];
            for (int n = 0; n < doc.length; n++) {
                int k = rand.nextInt(K);
                z[m][n] = k;
                nmk[m][k]++;
                nkt[k][doc[n]]++;
                nk[k]++;
            }
        }
    }

    /**
     * Prepares folding-in of the query documents. It freezes {@code phi} from
     * the current training counts and assigns every query token a uniformly
     * random topic. It also counts the query tokens into {@code Wq}.
     */
    @Override
    public void initq() {
        computePhi();
        nmkq = new int[Mq][K];
        zq = new int[Mq][];
        Wq = 0;
        for (int m = 0; m < Mq; m++) {
            int[] doc = wq[m];
            zq[m] = new int[doc.length];
            for (int n = 0; n < doc.length; n++) {
                int k = rand.nextInt(K);
                zq[m][n] = k;
                nmkq[m][k]++;
                Wq++;
            }
        }
    }

    /**
     * Runs {@code niter} collapsed Gibbs sweeps over all training tokens.
     * Each token's topic is redrawn from
     * {@code (nmk + alpha) * (nkt + beta) / (nk + V * beta)}, with the token
     * itself excluded from the counts. Prints the sweep number on every sweep.
     *
     * @throws IllegalStateException if the sampling weights become non-finite or non-positive
     */
    @Override
    public void run(int niter) {
        double[] pp = new double[K];
        double vBeta = V * beta;
        for (iter = 0; iter < niter; iter++) {
            System.out.println(iter);
            for (int m = 0; m < M; m++) {
                int[] doc = w[m];
                int[] zm = z[m];
                int[] nm = nmk[m];
                for (int n = 0; n < doc.length; n++) {
                    int t = doc[n];
                    int k = zm[n];
                    // remove the current assignment before computing the full conditional
                    nm[k]--;
                    nkt[k][t]--;
                    nk[k]--;
                    double psum = 0.0;
                    for (int kk = 0; kk < K; kk++) {
                        pp[kk] = (nm[kk] + alpha) * (nkt[kk][t] + beta) / (nk[kk] + vBeta);
                        psum += pp[kk];
                    }
                    k = sampleIndex(pp, psum);
                    zm[n] = k;
                    nm[k]++;
                    nkt[k][t]++;
                    nk[k]++;
                }
            }
            if (vis != null) {
                vis.setTopics(nkt);
            }
        }
    }

    /**
     * Runs {@code niter} fold-in sweeps over the query tokens. Topics are
     * redrawn from {@code (nmkq + alpha) * phi[k][t]}; {@code phi} is not
     * updated. Requires a prior call to {@link #initq()}.
     *
     * @throws IllegalStateException if the sampling weights become non-finite or non-positive
     */
    @Override
    public void runq(int niter) {
        double[] pp = new double[K];
        for (int qiter = 0; qiter < niter; qiter++) {
            for (int m = 0; m < Mq; m++) {
                int[] doc = wq[m];
                int[] zm = zq[m];
                int[] nm = nmkq[m];
                for (int n = 0; n < doc.length; n++) {
                    int t = doc[n];
                    nm[zm[n]]--;
                    double psum = 0.0;
                    for (int kk = 0; kk < K; kk++) {
                        pp[kk] = (nm[kk] + alpha) * phi[kk][t];
                        psum += pp[kk];
                    }
                    // sampleIndex always returns a valid topic or throws, so no guard is needed here
                    int k = sampleIndex(pp, psum);
                    zm[n] = k;
                    nm[k]++;
                }
            }
        }
    }

    /**
     * Returns the perplexity of the query documents,
     * {@code exp(-sum log p(w) / Wq)}. Each token's probability is
     * {@code sum_k theta[m][k] * phi[k][t]}, and theta is the smoothed
     * query document-topic estimate. Requires {@link #initq()} (and usually
     * {@link #runq(int)}) to have been called.
     */
    @Override
    public double ppx() {
        double loglik = 0.0;
        double kAlpha = K * alpha;
        double[] thetaq = new double[K];
        for (int m = 0; m < Mq; m++) {
            int[] doc = wq[m];
            double norm = doc.length + kAlpha;
            for (int k = 0; k < K; k++) {
                thetaq[k] = (nmkq[m][k] + alpha) / norm;
            }
            for (int n = 0; n < doc.length; n++) {
                int t = doc[n];
                double p = 0.0;
                for (int k = 0; k < K; k++) {
                    p += thetaq[k] * phi[k][t];
                }
                loglik += Math.log(p);
            }
        }
        return Math.exp(-loglik / (double) Wq);
    }

    @Override
    public String toString() {
        return String.format("LDA: M = %d, K = %d, V = %d, alpha = %2.5f, beta = %2.5f", this.M, this.K, this.V, this.alpha, this.beta);
    }

    /** Recomputes {@code phi[k][t] = (nkt[k][t] + beta) / (nk[k] + V * beta)} from the training counts. */
    private void computePhi() {
        double vBeta = V * beta;
        phi = new double[K][V];
        for (int k = 0; k < K; k++) {
            double norm = nk[k] + vBeta;
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + beta) / norm;
            }
        }
    }

    /** Returns {@code theta[m][k] = (nmk[m][k] + alpha) / (|w[m]| + K * alpha)} for the training documents. */
    private double[][] computeTheta() {
        double kAlpha = K * alpha;
        double[][] theta = new double[M][K];
        for (int m = 0; m < M; m++) {
            double norm = w[m].length + kAlpha;
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + alpha) / norm;
            }
        }
        return theta;
    }

    /**
     * Draws an index proportional to the unnormalised {@code weights} using a
     * single {@code rand.nextDouble()}.
     *
     * @param weights non-negative unnormalised probabilities
     * @param total   their sum, as accumulated by the caller
     * @return an index in {@code [0, weights.length)}
     * @throws IllegalStateException if {@code total} is not a positive finite number
     */
    private int sampleIndex(double[] weights, double total) {
        if (!(total > 0.0) || Double.isInfinite(total)) {
            // fail fast instead of storing an out-of-range topic index
            throw new IllegalStateException("invalid sampling weights (sum = " + total + "): " + Vectors.print(weights));
        }
        double u = rand.nextDouble() * total;
        double cum = 0.0;
        int last = weights.length - 1;
        for (int k = 0; k < last; k++) {
            cum += weights[k];
            if (u <= cum) {
                return k;
            }
        }
        // defensive: floating-point rounding must never push the draw past the last bucket
        return last;
    }
}

