/*
 * Decompiled with CFR 0.152.
 */
package org.knowceans.topics.simple.sandbox;

import java.util.Random;
import org.knowceans.corpus.NumCorpus;
import org.knowceans.topics.simple.ISimpleGibbs;
import org.knowceans.topics.simple.ISimplePpx;
import org.knowceans.topics.simple.ISimpleQueryGibbs;
import org.knowceans.util.CokusRandom;
import org.knowceans.util.ParallelFor;
import org.knowceans.util.StopWatch;
import org.knowceans.util.Vectors;

/**
 * Collapsed Gibbs sampler for latent Dirichlet allocation whose sweeps over documents are
 * distributed across {@code P} worker threads via {@link ParallelFor}.
 *
 * <p>Training documents are resampled in {@link #run(int)}. Held-out ("query") documents are
 * resampled in {@link #runq(int)} against topic-term probabilities {@code phi} that
 * {@link #initq()} computes from the current training counts and then keeps fixed.
 * {@link #ppx()} reports the perplexity of the query documents.
 *
 * <p>During {@link #run(int)} the shared counts {@code nkt[k]} and {@code nk[k]} are modified
 * only while holding the monitor of the row {@code nkt[k]}, but they are read without locking
 * when the sampling weights are computed. A thread may therefore observe counts that other
 * threads are in the middle of changing.
 */
public class LdaGibbsParA implements ISimpleGibbs, ISimpleQueryGibbs, ISimplePpx {
    /** Training documents: {@code w[m][n]} is the term index (in {@code [0, V)}) of token n of document m. */
    private final int[][] w;
    /** Query documents, same layout as {@link #w}. */
    private final int[][] wq;
    /** Topic counts per training document, {@code [M][K]}. */
    private int[][] nmk;
    /** Topic counts per query document, {@code [Mq][K]}. */
    private int[][] nmkq;
    /** Term counts per topic over the training documents, {@code [K][V]}. */
    private int[][] nkt;
    /** Number of training tokens currently assigned to each topic. */
    private int[] nk;
    /** Smoothed topic-term probabilities, computed by {@link #initq()}. */
    private double[][] phi;
    /** Topic assignment of every training token. */
    private int[][] z;
    /** Topic assignment of every query token. */
    private int[][] zq;
    /** Smoothing added to every document-topic count. */
    private final double alpha;
    /** Smoothing added to every topic-term count. */
    private final double beta;
    /** One random source per worker thread; entry 0 is also used by the single-threaded initialisers. */
    private final Random[] rand_;
    /** Index of the current training sweep inside {@link #run(int)}. */
    private int iter;
    /** Number of topics. */
    private final int K;
    /** Number of training documents. */
    private final int M;
    /** Number of query documents. */
    private final int Mq;
    /** Total number of query tokens, counted by {@link #initq()}. */
    private int Wq;
    /** Vocabulary size. */
    private final int V;
    /** Number of worker threads. */
    private final int P;

    /**
     * Demo driver.
     *
     * <p>It loads a corpus, splits it into training and query parts, and prints the query
     * perplexity twice: once computed from the random initial state, and once after
     * {@code niter} timed training sweeps.
     *
     * @param args optional; {@code args[0]} overrides the corpus path (default {@code nips/nips.corpus})
     */
    public static void main(String[] args) {
        int niter = 10;
        int niterq = 10;
        String path = args.length > 0 ? args[0] : "nips/nips.corpus";
        NumCorpus corpus = new NumCorpus(path);
        CokusRandom rand = new CokusRandom();
        corpus.split(10, 2, rand);
        NumCorpus train = (NumCorpus) corpus.getTrainCorpus();
        NumCorpus test = (NumCorpus) corpus.getTestCorpus();
        int[][] w = train.getDocWords(rand);
        int[][] wq = test.getDocWords(rand);
        int K = 100;
        int V = corpus.getNumTerms();
        double alpha = 0.1;
        double beta = 0.1;
        int P = Runtime.getRuntime().availableProcessors();
        LdaGibbsParA gs = new LdaGibbsParA(w, wq, K, V, alpha, beta, rand, P);
        gs.init();
        System.out.println(gs);
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
        StopWatch.start();
        gs.run(niter);
        System.out.println(StopWatch.format(StopWatch.stop()));
        gs.initq();
        gs.runq(niterq);
        System.out.println(gs.ppx());
    }

    /**
     * Creates a sampler over the given training and query documents.
     *
     * @param w training documents ({@code w[m][n]} is a term index in {@code [0, V)})
     * @param wq query documents, same layout as {@code w}
     * @param K number of topics
     * @param V vocabulary size
     * @param alpha smoothing added to document-topic counts
     * @param beta smoothing added to topic-term counts
     * @param rand source of the seeds for the per-thread random generators
     * @param P number of worker threads; must be at least 1
     * @throws IllegalArgumentException if {@code P < 1}
     */
    public LdaGibbsParA(int[][] w, int[][] wq, int K, int V, double alpha, double beta, Random rand, int P) {
        if (P < 1) {
            // init()/initq() use rand_[0], and ParallelFor needs at least one worker.
            throw new IllegalArgumentException("P must be at least 1, got " + P);
        }
        this.w = w;
        this.wq = wq;
        this.K = K;
        this.M = w.length;
        this.Mq = wq.length;
        this.V = V;
        this.alpha = alpha;
        this.beta = beta;
        this.P = P;
        this.rand_ = new CokusRandom[P];
        for (int i = 0; i < P; i++) {
            this.rand_[i] = new CokusRandom(rand.nextInt());
        }
    }

    /**
     * Assigns every training token a uniformly random topic and builds the count tables
     * {@code nmk}, {@code nkt} and {@code nk} from those assignments.
     */
    @Override
    public void init() {
        this.nmk = new int[this.M][this.K];
        this.nkt = new int[this.K][this.V];
        this.nk = new int[this.K];
        this.z = new int[this.M][];
        Random rng = this.rand_[0];
        for (int m = 0; m < this.M; m++) {
            int[] doc = this.w[m];
            this.z[m] = new int[doc.length];
            for (int n = 0; n < doc.length; n++) {
                int k = rng.nextInt(this.K);
                this.z[m][n] = k;
                this.nmk[m][k]++;
                this.nkt[k][doc[n]]++;
                this.nk[k]++;
            }
        }
    }

    /**
     * Prepares the query documents.
     *
     * <p>It first computes {@code phi[k][t] = (nkt[k][t] + beta) / (nk[k] + V * beta)} from the
     * current training counts. It then assigns every query token a uniformly random topic and
     * counts the query tokens into {@code Wq}. Requires {@link #init()} to have run.
     */
    @Override
    public void initq() {
        this.phi = new double[this.K][this.V];
        double vBeta = (double) this.V * this.beta;
        for (int k = 0; k < this.K; k++) {
            double norm = (double) this.nk[k] + vBeta;
            for (int t = 0; t < this.V; t++) {
                this.phi[k][t] = ((double) this.nkt[k][t] + this.beta) / norm;
            }
        }
        this.nmkq = new int[this.Mq][this.K];
        this.zq = new int[this.Mq][];
        this.Wq = 0;
        Random rng = this.rand_[0];
        for (int m = 0; m < this.Mq; m++) {
            int[] doc = this.wq[m];
            this.zq[m] = new int[doc.length];
            for (int n = 0; n < doc.length; n++) {
                int k = rng.nextInt(this.K);
                this.zq[m][n] = k;
                this.nmkq[m][k]++;
            }
            this.Wq += doc.length;
        }
    }

    /**
     * Runs {@code niter} parallel sweeps over the training documents.
     *
     * <p>Prints the sweep index to stdout before each sweep. The worker pool is shut down even
     * if a sweep throws.
     *
     * @param niter number of sweeps
     */
    @Override
    public void run(int niter) {
        ParallelFor eachdoc = new ParallelFor(this.P) {
            // One scratch weight buffer per worker thread, so threads never share one.
            final double[][] pp = new double[LdaGibbsParA.this.P][LdaGibbsParA.this.K];

            @Override
            public void process(int m, int thread) {
                LdaGibbsParA.this.sampleTrainDoc(m, this.pp[thread], LdaGibbsParA.this.rand_[thread]);
            }
        };
        try {
            for (this.iter = 0; this.iter < niter; this.iter++) {
                System.out.println(this.iter);
                eachdoc.loop(this.M);
            }
        } finally {
            eachdoc.shutdown();
        }
    }

    /**
     * Runs {@code niter} parallel sweeps over the query documents, keeping {@code phi} fixed.
     *
     * <p>Requires {@link #initq()} to have run. The worker pool is shut down even if a sweep
     * throws.
     *
     * @param niter number of sweeps
     */
    @Override
    public void runq(int niter) {
        ParallelFor eachdoc = new ParallelFor(this.P) {
            // One scratch weight buffer per worker thread, so threads never share one.
            final double[][] pp = new double[LdaGibbsParA.this.P][LdaGibbsParA.this.K];

            @Override
            public void process(int m, int thread) {
                LdaGibbsParA.this.sampleQueryDoc(m, this.pp[thread], LdaGibbsParA.this.rand_[thread]);
            }
        };
        try {
            for (int qiter = 0; qiter < niter; qiter++) {
                eachdoc.loop(this.Mq);
            }
        } finally {
            eachdoc.shutdown();
        }
    }

    /**
     * Resamples the topic of every token of training document {@code m}.
     *
     * <p>The new topic is drawn with weight
     * {@code (nmk[m][k] + alpha) * (nkt[k][t] + beta) / (nk[k] + V * beta)}.
     * {@code z[m]} and {@code nmk[m]} are updated without locking. This assumes ParallelFor
     * hands each document index to a single thread per sweep (TODO confirm). Updates to the
     * shared {@code nkt[k]} / {@code nk[k]} hold the monitor of {@code nkt[k]}.
     *
     * @param m training document index
     * @param p scratch buffer of length at least {@code K}, owned by the calling thread
     * @param rng random source owned by the calling thread
     * @throws IllegalStateException if the topic weights do not sum to a finite value
     */
    private void sampleTrainDoc(int m, double[] p, Random rng) {
        int[] doc = this.w[m];
        int[] topics = this.z[m];
        int[] docTopic = this.nmk[m];
        double vBeta = (double) this.V * this.beta;
        for (int n = 0; n < doc.length; n++) {
            int t = doc[n];
            int k = topics[n];
            // Remove the token's current assignment from all counts.
            docTopic[k]--;
            synchronized (this.nkt[k]) {
                this.nkt[k][t]--;
                this.nk[k]--;
            }
            double psum = 0.0;
            for (int kk = 0; kk < this.K; kk++) {
                p[kk] = ((double) docTopic[kk] + this.alpha) * ((double) this.nkt[kk][t] + this.beta)
                        / ((double) this.nk[kk] + vBeta);
                psum += p[kk];
            }
            int knew = sampleIndex(p, this.K, psum, rng);
            // Record the new assignment and add it back to all counts.
            topics[n] = knew;
            docTopic[knew]++;
            synchronized (this.nkt[knew]) {
                this.nkt[knew][t]++;
                this.nk[knew]++;
            }
        }
    }

    /**
     * Resamples the topic of every token of query document {@code m}.
     *
     * <p>The new topic is drawn with weight {@code (nmkq[m][k] + alpha) * phi[k][t]}. Only
     * {@code zq[m]} and {@code nmkq[m]} are modified.
     *
     * @param m query document index
     * @param p scratch buffer of length at least {@code K}, owned by the calling thread
     * @param rng random source owned by the calling thread
     * @throws IllegalStateException if the topic weights do not sum to a finite value
     */
    private void sampleQueryDoc(int m, double[] p, Random rng) {
        int[] doc = this.wq[m];
        int[] topics = this.zq[m];
        int[] docTopic = this.nmkq[m];
        for (int n = 0; n < doc.length; n++) {
            int t = doc[n];
            docTopic[topics[n]]--;
            double psum = 0.0;
            for (int kk = 0; kk < this.K; kk++) {
                p[kk] = ((double) docTopic[kk] + this.alpha) * this.phi[kk][t];
                psum += p[kk];
            }
            int knew = sampleIndex(p, this.K, psum, rng);
            topics[n] = knew;
            docTopic[knew]++;
        }
    }

    /**
     * Draws an index from unnormalised weights by a linear cumulative-sum search.
     *
     * <p>Consumes exactly one {@code rng.nextDouble()}.
     *
     * @param weights non-negative weights; only the first {@code count} entries are used
     * @param count number of entries to consider
     * @param total sum of {@code weights[0..count)}, accumulated in index order
     * @param rng random source
     * @return the sampled index in {@code [0, count)}
     * @throws IllegalStateException if no index satisfies the search (weights summing to NaN or infinity)
     */
    private static int sampleIndex(double[] weights, int count, double total, Random rng) {
        double u = rng.nextDouble() * total;
        double cum = 0.0;
        for (int i = 0; i < count; i++) {
            cum += weights[i];
            if (u <= cum) {
                return i;
            }
        }
        // With finite weights the final cum equals total and u = r * total <= total for r < 1,
        // so we only get here for non-finite weights. Fail before an out-of-range topic (== count)
        // is written into the assignment arrays.
        throw new IllegalStateException("cannot sample topic: sum of weights = " + total);
    }

    /**
     * Returns the perplexity of the query documents, {@code exp(-loglik / Wq)}.
     *
     * <p>Here {@code loglik} sums {@code log(sum_k theta[m][k] * phi[k][t])} over all query
     * tokens, with {@code theta[m][k] = (nmkq[m][k] + alpha) / (|wq[m]| + K * alpha)}.
     * Requires {@link #initq()} to have run. Returns NaN when there are no query tokens
     * ({@code Wq == 0}).
     *
     * @return the query perplexity
     */
    @Override
    public double ppx() {
        double loglik = 0.0;
        double kAlpha = (double) this.K * this.alpha;
        double[] theta = new double[this.K];
        for (int m = 0; m < this.Mq; m++) {
            int[] doc = this.wq[m];
            double norm = (double) doc.length + kAlpha;
            for (int k = 0; k < this.K; k++) {
                theta[k] = ((double) this.nmkq[m][k] + this.alpha) / norm;
            }
            for (int t : doc) {
                double sum = 0.0;
                for (int k = 0; k < this.K; k++) {
                    sum += theta[k] * this.phi[k][t];
                }
                loglik += Math.log(sum);
            }
        }
        return Math.exp(-loglik / (double) this.Wq);
    }

    /** Returns a one-line summary of the sampler configuration. */
    @Override
    public String toString() {
        return String.format("LDA ParA: M = %d, K = %d, V = %d, alpha = %2.5f, beta = %2.5f, P = %d", this.M, this.K, this.V, this.alpha, this.beta, this.P);
    }
}

