/*
 * Decompiled with CFR 0.152.
 */
package org.knowceans.topics.simple;

import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.SortedSet;
import java.util.TreeSet;
import org.knowceans.corpus.NumCorpus;
import org.knowceans.corpus.VisCorpus;
import org.knowceans.topics.simple.ISimpleGibbs;
import org.knowceans.topics.simple.ISimplePpx;
import org.knowceans.topics.simple.ISimpleQueryGibbs;
import org.knowceans.topics.simple.LdaTopics;
import org.knowceans.topics.simple.TopicMatrixPanel;
import org.knowceans.util.ArrayIo;
import org.knowceans.util.ArrayUtils;
import org.knowceans.util.CokusRandom;
import org.knowceans.util.DirichletEstimation;
import org.knowceans.util.IndexQuickSort;
import org.knowceans.util.RandomSamplers;
import org.knowceans.util.StopWatch;
import org.knowceans.util.Vectors;

/**
 * Collapsed Gibbs sampler for an LDA model whose number of topics K is not fixed
 * ("infinite LDA"). Global topic weights {@code tau} are resampled from a Dirichlet
 * over Antoniak-sampled table counts plus {@code gamma} for a not-yet-used topic
 * (the direct-assignment construction of the hierarchical Dirichlet process). Each
 * word may open a new topic with weight {@code alpha * tauNew / V}.
 *
 * <p>Topic bookkeeping: the per-topic structures ({@link #nkt}, {@link #nk}, the
 * entries of {@link #nmk}, {@link #tau}) are indexed by a stable topic id. The ids
 * in use are listed in {@link #kactive}; the position of an id in that list is its
 * "dense" index {@code kk} in {@code 0..K-1}. Ids of topics that became empty are
 * kept in {@link #kgaps} and reused before fresh ids are allocated.
 * {@link #packTopics()} renumbers ids to {@code 0..K-1} and drops the gaps.
 */
public class IldaGibbs
implements ISimpleGibbs,
ISimpleQueryGibbs,
ISimplePpx {
    /** Optional live display of the topic-term counts; only set by {@link #main}. */
    private static TopicMatrixPanel vis;
    /** Training documents: {@code w[m][n]} is the term id of token n in document m. */
    private int[][] w;
    /** Query (held-out) documents, same layout as {@link #w}. */
    private int[][] wq;
    /** Freed topic ids available for reuse; the smallest is reused first. */
    private SortedSet<Integer> kgaps;
    /** Topic ids currently in use; list position = dense topic index. */
    private List<Integer> kactive;
    /** Per-document topic counts, indexed by topic id. */
    private List<Integer>[] nmk;
    /** Per-query-document topic counts, indexed by dense topic index (see initq). */
    private int[][] nmkq;
    /** Topic-term counts, indexed by topic id. */
    private List<int[]> nkt;
    /** Topic token totals, indexed by topic id. */
    private List<Integer> nk;
    /**
     * Topic-term probabilities. After {@link #initq()} rows are dense topic indices;
     * after {@link #print} rows are (packed) topic ids.
     */
    private double[][] phi;
    /** Topic id assigned to each training token. */
    private int[][] z;
    /** Dense topic index assigned to each query token. */
    private int[][] zq;
    /** Scratch buffer of unnormalised sampling weights; always longer than K. */
    private double[] pp;
    /** Head-room added whenever {@link #pp} is (re)allocated. */
    public final int ppstep = 10;
    /** Document-level concentration. */
    private double alpha;
    /** Global topic weights, indexed by topic id. */
    private ArrayList<Double> tau;
    /** Global weight reserved for a topic that does not exist yet. */
    private double tauNew;
    /** Symmetric Dirichlet parameter of the topic-term distributions. */
    private double beta;
    /** Top-level concentration; 0 switches the sampler to fixed-K mode. */
    private double gamma;
    /** Gamma prior (shape, rate) on alpha, used in updateHyper. */
    double aalpha = 5.0;
    double balpha = 0.1;
    /** Prior parameters passed to DirichletEstimation.estimateAlphaMap for beta. */
    double abeta = 0.1;
    double bbeta = 0.1;
    /** Gamma prior (shape, rate) on gamma, used in updateHyper. */
    double agamma = 5.0;
    double bgamma = 0.1;
    /** Number of auxiliary-variable rounds per updateHyper call. */
    int R = 10;
    /** Total table count (sum of Antoniak draws) from the last updateTau. */
    private double T;
    private Random rand;
    RandomSamplers samp;
    /** Current iteration of {@link #run(int)}. */
    private int iter;
    /** Number of active topics. */
    private int K;
    private int M;
    private int Mq;
    /** Number of query tokens. */
    private int Wq;
    /** Vocabulary size. */
    private int V;
    /** True once {@link #init()} has assigned every training token a topic. */
    private boolean inited = false;
    private boolean fixedK = false;
    private boolean fixedHyper = false;

    /**
     * Demo driver: trains on a synthetic (or file) corpus, reports query perplexity
     * before and after training, packs the topics and writes a result file.
     */
    public static void main(String[] args) {
        NumCorpus corpus;
        int niter = 1500;
        int niterq = 10;
        String filebase = "nips/nips";
        boolean usefile = false;
        boolean display = true;
        CokusRandom rand = new CokusRandom(56567651L);
        if (usefile) {
            corpus = new NumCorpus(filebase + ".corpus");
            if (display) {
                vis = new TopicMatrixPanel(900, 400, (int) Math.sqrt(corpus.getNumTerms()), 1);
            }
        } else {
            int K = 10;
            corpus = VisCorpus.generateLdaCorpus(K, 1000, 200);
            if (display) {
                vis = new TopicMatrixPanel(900, 400, K, 300 / K);
            }
        }
        corpus.split(10, 2, rand);
        NumCorpus train = (NumCorpus) corpus.getTrainCorpus();
        NumCorpus test = (NumCorpus) corpus.getTestCorpus();
        int[][] w = train.getDocWords(rand);
        int[][] wq = test.getDocWords(rand);
        int K0 = 0;
        int V = corpus.getNumTerms();
        double alpha = 1.0;
        double beta = 0.1;
        double gamma = 1.5;
        IldaGibbs gs = new IldaGibbs(w, wq, K0, V, alpha, beta, gamma, rand);
        gs.init();
        System.out.println("initialised");
        System.out.println(gs);
        gs.initq();
        gs.runq(niterq);
        System.out.println("perplexity = " + gs.ppx());
        StopWatch.start();
        System.out.println("starting Gibbs sampler with " + niter + " iterations");
        gs.run(niter);
        System.out.println(StopWatch.format(StopWatch.stop()));
        gs.initq();
        gs.runq(niterq);
        System.out.println("perplexity = " + gs.ppx());
        System.out.println(gs);
        gs.packTopics();
        System.out.println("finished");
        System.out.println(gs);
        // NOTE(review): results are written only for the synthetic corpus, using the
        // "nips" file base -- confirm this condition is not inverted.
        if (!usefile) {
            try {
                PrintStream bw = new PrintStream(filebase + ".ilda.result");
                try {
                    gs.print(bw, filebase, corpus.getOrigDocIds()[0], train.getNumWords());
                } finally {
                    // close even if printing fails
                    bw.close();
                }
                System.out.println("done");
            }
            catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Writes per-document and per-topic summaries to {@code out} and saves theta/phi
     * as binary matrices next to {@code filebase}.
     *
     * @param out     destination of the text report
     * @param filebase path prefix handed to LdaTopics and used for the .zip outputs
     * @param docids  original document ids, indexed by training document
     * @param W       number of training tokens (used to report relative topic sizes)
     * @throws IllegalStateException if topic ids are not packed (call packTopics first)
     */
    private void print(PrintStream out, String filebase, int[] docids, int W) {
        // rows/columns 0..K-1 are read directly, which is only valid without gaps
        if (!this.kgaps.isEmpty()) {
            throw new IllegalStateException("packTopics() must be called before print()");
        }
        this.phi = new double[this.K][this.V];
        double[][] theta = new double[this.M][this.K];
        for (int k = 0; k < this.K; k++) {
            int[] counts = this.nkt.get(k);
            double norm = (double) this.nk.get(k).intValue() + this.beta * (double) this.V;
            for (int t = 0; t < this.V; t++) {
                this.phi[k][t] = ((double) counts[t] + this.beta) / norm;
            }
        }
        for (int m = 0; m < this.M; m++) {
            double norm = (double) this.w[m].length + this.alpha * (double) this.K;
            for (int k = 0; k < this.K; k++) {
                theta[m][k] = ((double) this.nmk[m].get(k).intValue() + this.alpha) / norm;
            }
        }
        LdaTopics lt = new LdaTopics(filebase, theta, this.phi);
        for (int m = 0; m < this.M; m++) {
            out.println(lt.printDocument(m, docids[m], 10, false, true));
        }
        // topics listed from largest to smallest token count
        int[] byWeight = IndexQuickSort.sort(this.nk);
        IndexQuickSort.reverse(byWeight);
        for (int rank = 0; rank < this.K; rank++) {
            int k = byWeight[rank];
            out.println(String.format("%d (%2.5f / %d): %s", rank,
                    (double) this.nk.get(k).intValue() / (double) W * (double) this.K,
                    this.K, lt.printTopic(k, 20)));
        }
        ArrayIo.saveBinaryMatrix(filebase + ".ilda.theta.zip", theta);
        ArrayIo.saveBinaryMatrix(filebase + ".ilda.phi.zip", this.phi);
    }

    /**
     * @param w     training documents as term ids
     * @param wq    query documents as term ids
     * @param K     initial number of topics (may be 0)
     * @param V     vocabulary size
     * @param alpha document-level concentration
     * @param beta  topic-term Dirichlet parameter
     * @param gamma top-level concentration; 0 fixes the number of topics
     * @param rand  random source for all sampling
     */
    public IldaGibbs(int[][] w, int[][] wq, int K, int V, double alpha, double beta, double gamma, Random rand) {
        this.w = w;
        this.wq = wq;
        this.K = K;
        this.alpha = alpha;
        this.beta = beta;
        this.gamma = gamma;
        if (gamma == 0.0) {
            this.fixedK = true;
        }
        this.M = w.length;
        this.Mq = wq.length;
        this.V = V;
        this.rand = rand;
        this.samp = new RandomSamplers(rand);
    }

    /**
     * Allocates the count structures and assigns every training token a topic by
     * running one sweep in which no previous assignments are removed.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void init() {
        // make sure the initial sweep does not try to remove (non-existent) assignments
        this.inited = false;
        this.nmk = new ArrayList[this.M];
        this.nkt = new ArrayList<int[]>();
        this.nk = new ArrayList<Integer>();
        this.z = new int[this.M][];
        for (int m = 0; m < this.M; m++) {
            this.nmk[m] = new ArrayList<Integer>();
            for (int k = 0; k < this.K; k++) {
                this.nmk[m].add(0);
            }
            this.z[m] = new int[this.w[m].length];
        }
        this.kactive = new ArrayList<Integer>();
        this.kgaps = new TreeSet<Integer>();
        this.tau = new ArrayList<Double>();
        // uniform start; K == 0 (the default in main) would otherwise give 1/0
        double tau0 = this.K > 0 ? 1.0 / (double) this.K : 1.0;
        for (int k = 0; k < this.K; k++) {
            this.kactive.add(k);
            this.nkt.add(new int[this.V]);
            this.nk.add(0);
            this.tau.add(tau0);
        }
        this.tauNew = tau0;
        this.pp = new double[this.K + this.ppstep];
        this.run(1);
        if (!this.fixedK) {
            this.updateTau();
        }
        this.inited = true;
    }

    /**
     * Prepares query sampling: freezes phi from the current training counts and
     * randomly assigns every query token a topic. Rows of phi and columns of nmkq are
     * dense topic indices (positions in kactive); must be called again after run().
     */
    @Override
    public void initq() {
        this.phi = new double[this.K][this.V];
        for (int kk = 0; kk < this.K; kk++) {
            int k = this.kactive.get(kk);
            int[] counts = this.nkt.get(k);
            double norm = (double) this.nk.get(k).intValue() + (double) this.V * this.beta;
            for (int t = 0; t < this.V; t++) {
                this.phi[kk][t] = ((double) counts[t] + this.beta) / norm;
            }
        }
        this.nmkq = new int[this.Mq][this.K];
        this.zq = new int[this.Mq][];
        this.Wq = 0;
        for (int m = 0; m < this.Mq; m++) {
            this.zq[m] = new int[this.wq[m].length];
            for (int n = 0; n < this.wq[m].length; n++) {
                int kk = this.rand.nextInt(this.K);
                this.zq[m][n] = kk;
                this.nmkq[m][kk]++;
                ++this.Wq;
            }
        }
    }

    /**
     * Runs {@code niter} sweeps over all training tokens, resampling tau after each
     * sweep (unless K is fixed) and the hyperparameters after the 10th sweep (unless
     * they are fixed).
     */
    @Override
    public void run(int niter) {
        for (this.iter = 0; this.iter < niter; ++this.iter) {
            System.out.println(this.iter);
            System.out.println(this);
            for (int m = 0; m < this.M; m++) {
                for (int n = 0; n < this.w[m].length; n++) {
                    this.sampleWord(m, n);
                }
            }
            if (vis != null) {
                vis.setTopics(this.nkt);
            }
            if (!this.fixedK) {
                this.updateTau();
            }
            if (this.iter > 10 && !this.fixedHyper) {
                this.updateHyper();
            }
        }
    }

    /**
     * Resamples the topic of token {@code n} in training document {@code m}. Before
     * init() has completed, the token has no previous assignment and nothing is
     * removed from the counts. A topic whose last token leaves is retired (not in
     * fixed-K mode, where it could never be re-created).
     */
    private void sampleWord(int m, int n) {
        int t = this.w[m][n];
        int kold = -1;
        if (this.inited) {
            kold = this.z[m][n];
            this.nmk[m].set(kold, this.nmk[m].get(kold) - 1);
            this.nkt.get(kold)[t]--;
            this.nk.set(kold, this.nk.get(kold) - 1);
        }
        // unnormalised full conditional over the active topics (dense index kk)
        double psum = 0.0;
        for (int kk = 0; kk < this.K; kk++) {
            int k = this.kactive.get(kk);
            this.pp[kk] = ((double) this.nmk[m].get(k).intValue() + this.alpha * this.tau.get(k))
                    * ((double) this.nkt.get(k)[t] + this.beta)
                    / ((double) this.nk.get(k).intValue() + (double) this.V * this.beta);
            psum += this.pp[kk];
        }
        // candidate K opens a new topic; only offered when K may grow
        int ncand = this.K;
        if (!this.fixedK) {
            this.pp[this.K] = this.alpha * this.tauNew / (double) this.V;
            psum += this.pp[this.K];
            ncand = this.K + 1;
        }
        double u = this.rand.nextDouble() * psum;
        // inverse-CDF draw; if rounding leaves u above the last partial sum, take the
        // last candidate instead of reading past the filled part of pp
        int knew = 0;
        double cum = 0.0;
        while (knew < ncand - 1) {
            cum += this.pp[knew];
            if (u <= cum) {
                break;
            }
            ++knew;
        }
        if (knew < this.K) {
            int k = this.kactive.get(knew);
            this.z[m][n] = k;
            this.nmk[m].set(k, this.nmk[m].get(k) + 1);
            this.nkt.get(k)[t]++;
            this.nk.set(k, this.nk.get(k) + 1);
        } else {
            assert (!this.fixedK);
            this.z[m][n] = this.spawnTopic(m, t);
            this.updateTau();
            System.out.println("K = " + this.K);
        }
        if (this.inited && !this.fixedK && this.nk.get(kold) == 0) {
            // Integer.valueOf: remove the element, not the position
            this.kactive.remove(Integer.valueOf(kold));
            this.kgaps.add(kold);
            assert (Vectors.sum(this.nkt.get(kold)) == 0 && this.nk.get(kold) == 0 && this.nmk[m].get(kold) == 0);
            --this.K;
            System.out.println("K = " + this.K);
            this.updateTau();
        }
    }

    /**
     * Runs {@code niter} sweeps over the query tokens with phi held fixed. Requires
     * initq() to have been called after the last change to K.
     */
    @Override
    public void runq(int niter) {
        for (int qiter = 0; qiter < niter; qiter++) {
            for (int m = 0; m < this.nmkq.length; m++) {
                for (int n = 0; n < this.wq[m].length; n++) {
                    int t = this.wq[m][n];
                    this.nmkq[m][this.zq[m][n]]--;
                    double psum = 0.0;
                    for (int kk = 0; kk < this.K; kk++) {
                        this.pp[kk] = ((double) this.nmkq[m][kk] + this.alpha) * this.phi[kk][t];
                        psum += this.pp[kk];
                    }
                    double u = this.rand.nextDouble() * psum;
                    // clamp to the last topic if rounding leaves u above the total
                    int knew = 0;
                    double cum = 0.0;
                    while (knew < this.K - 1) {
                        cum += this.pp[knew];
                        if (u <= cum) {
                            break;
                        }
                        ++knew;
                    }
                    this.zq[m][n] = knew;
                    this.nmkq[m][knew]++;
                }
            }
        }
    }

    /**
     * Creates a topic holding the single token (document m, term t), reusing the
     * smallest freed id if there is one. Increments K and grows pp if needed.
     *
     * @return the topic id of the new topic
     */
    private int spawnTopic(int m, int t) {
        int k;
        if (!this.kgaps.isEmpty()) {
            // a freed id's count rows are all zero, so only the new token is set
            k = this.kgaps.first();
            this.kgaps.remove(k);
            this.kactive.add(k);
            this.nmk[m].set(k, 1);
            this.nkt.get(k)[t] = 1;
            this.nk.set(k, 1);
        } else {
            // no gaps: ids 0..K-1 are all active, so K is the next fresh id
            k = this.K;
            for (int i = 0; i < this.M; i++) {
                this.nmk[i].add(0);
            }
            this.kactive.add(k);
            this.nmk[m].set(k, 1);
            int[] counts = new int[this.V];
            counts[t] = 1;
            this.nkt.add(counts);
            this.nk.add(1);
            // placeholder; the caller resamples tau immediately afterwards
            this.tau.add(0.0);
        }
        ++this.K;
        // pp[K] must exist for the new-topic candidate
        if (this.pp.length <= this.K) {
            this.pp = new double[this.K + this.ppstep];
        }
        return k;
    }

    /**
     * Renumbers topic ids so that the active topics occupy 0..K-1 in order of
     * decreasing size, and drops all gap ids from the per-topic structures
     * (including tau).
     */
    public void packTopics() {
        int ngaps = this.kgaps.size();
        // knew2k[new id] = old id; gaps (count 0) sort to the end
        int[] knew2k = IndexQuickSort.sort(this.nk);
        IndexQuickSort.reverse(knew2k);
        IndexQuickSort.reorder(this.nk, knew2k);
        IndexQuickSort.reorder(this.nkt, knew2k);
        for (int i = 0; i < ngaps; i++) {
            this.nk.remove(this.nk.size() - 1);
            this.nkt.remove(this.nkt.size() - 1);
        }
        for (int m = 0; m < this.M; m++) {
            IndexQuickSort.reorder(this.nmk[m], knew2k);
            for (int i = 0; i < ngaps; i++) {
                this.nmk[m].remove(this.nmk[m].size() - 1);
            }
        }
        // topic weights must follow their topics to the new ids
        ArrayList<Double> packedTau = new ArrayList<Double>(this.K);
        for (int knew = 0; knew < this.K; knew++) {
            packedTau.add(this.tau.get(knew2k[knew]));
        }
        this.tau = packedTau;
        this.kgaps.clear();
        int[] k2knew = IndexQuickSort.inverse(knew2k);
        for (int i = 0; i < this.K; i++) {
            this.kactive.set(i, k2knew[this.kactive.get(i)]);
        }
        for (int m = 0; m < this.M; m++) {
            for (int n = 0; n < this.w[m].length; n++) {
                this.z[m][n] = k2knew[this.z[m][n]];
            }
        }
    }

    /**
     * Resamples the global topic weights: table counts per active topic are drawn
     * with randAntoniak (or equal the count when it is 0 or 1), then
     * (tau_1..tau_K, tauNew) ~ Dir(m_1..m_K, gamma). Also stores the table total in T.
     */
    private void updateTau() {
        double[] mk = new double[this.K + 1];
        for (int kk = 0; kk < this.K; kk++) {
            int k = this.kactive.get(kk);
            double atau = this.alpha * this.tau.get(k);
            for (int m = 0; m < this.M; m++) {
                int count = this.nmk[m].get(k);
                if (count > 1) {
                    mk[kk] += (double) this.samp.randAntoniak(atau, count);
                } else {
                    mk[kk] += (double) count;
                }
            }
        }
        this.T = Vectors.sum(mk);
        mk[this.K] = this.gamma;
        double[] tt = this.samp.randDir(mk);
        for (int kk = 0; kk < this.K; kk++) {
            this.tau.set(this.kactive.get(kk), tt[kk]);
        }
        this.tauNew = tt[this.K];
    }

    /**
     * Resamples gamma and alpha by R rounds of auxiliary-variable updates under
     * their Gamma priors, then re-estimates beta from the active topics' counts.
     */
    private void updateHyper() {
        for (int r = 0; r < this.R; r++) {
            double eta = this.samp.randBeta(this.gamma + 1.0, this.T);
            double bloge = this.bgamma - Math.log(eta);
            double pie = 1.0 / (1.0 + this.T * bloge / (this.agamma + (double) this.K - 1.0));
            int u = this.samp.randBernoulli(pie);
            this.gamma = this.samp.randGamma(this.agamma + (double) this.K - 1.0 + (double) u, 1.0 / bloge);
            double qs = 0.0;
            double qw = 0.0;
            for (int m = 0; m < this.M; m++) {
                qs += (double) this.samp.randBernoulli((double) this.w[m].length / ((double) this.w[m].length + this.alpha));
                qw += Math.log(this.samp.randBeta(this.alpha + 1.0, this.w[m].length));
            }
            this.alpha = this.samp.randGamma(this.aalpha + this.T - qs, 1.0 / (this.balpha - qw));
        }
        // totals and term rows must describe the same (active) topics, row for row
        int[] ak = new int[this.K];
        int[][] akt = new int[this.K][];
        for (int kk = 0; kk < this.K; kk++) {
            int k = this.kactive.get(kk);
            ak[kk] = this.nk.get(k);
            akt[kk] = this.nkt.get(k);
        }
        this.beta = DirichletEstimation.estimateAlphaMap(akt, ak, this.beta, this.abeta, this.bbeta);
    }

    /**
     * Perplexity of the query documents under the current query assignments and the
     * phi frozen by initq(): exp(-sum log p(w) / Wq).
     */
    @Override
    public double ppx() {
        double loglik = 0.0;
        double[] thetaq = new double[this.K];
        for (int m = 0; m < this.Mq; m++) {
            double norm = (double) this.wq[m].length + (double) this.K * this.alpha;
            for (int kk = 0; kk < this.K; kk++) {
                thetaq[kk] = ((double) this.nmkq[m][kk] + this.alpha) / norm;
            }
            for (int n = 0; n < this.wq[m].length; n++) {
                int t = this.wq[m][n];
                double sum = 0.0;
                for (int kk = 0; kk < this.K; kk++) {
                    sum += thetaq[kk] * this.phi[kk][t];
                }
                loglik += Math.log(sum);
            }
        }
        return Math.exp(-loglik / (double) this.Wq);
    }

    @Override
    public String toString() {
        return String.format("ILDA: M = %d, K = %d, V = %d, alpha = %2.5f, beta = %2.5f, gamma = %2.5f", this.M, this.K, this.V, this.alpha, this.beta, this.gamma);
    }
}

