/*
 * Decompiled with CFR 0.152.
 */
package org.knowceans.util;

import org.knowceans.util.ArmSampler;
import org.knowceans.util.Gamma;
import org.knowceans.util.RandomSamplers;
import org.knowceans.util.Vectors;

/**
 * Estimators for the parameters of Dirichlet and Dirichlet-multinomial (Polya)
 * distributions.
 * <p>
 * The Dirichlet updates follow the formulas in T. Minka, "Estimating a
 * Dirichlet distribution":
 * <ul>
 * <li>fixed-point and Newton iterations for alpha;</li>
 * <li>alternating mean/precision updates;</li>
 * <li>moment matching.</li>
 * </ul>
 * Count-based methods take a matrix {@code nmk[m][k]} of category counts per
 * row {@code m}. Where needed, they also take the row totals {@code nm[m]}.
 */
public class DirichletEstimation {

    /** Whether {@link #estimateAlpha(double[][])} uses Newton's method instead of the fixed-point iteration. */
    private static final boolean USE_NEWTON = false;

    /** Number of alternating precision/mean rounds in {@link #estimateMeanPrec(double[][])}. */
    private static final int MEAN_PREC_ROUNDS = 5;

    /** Number of iterations of the mean update in {@link #meanGenNewton}. */
    private static final int MEAN_ITERATIONS = 100;

    /**
     * Computes a maximum-likelihood estimate of the Dirichlet parameter vector
     * from observed probability vectors.
     * <p>
     * The estimate starts from a moment-matching guess and is refined by the
     * fixed-point iteration (or Newton's method if {@link #USE_NEWTON} is set).
     *
     * @param pp observations, one probability vector of length K per row
     * @return estimated alpha of length K
     */
    public static double[] estimateAlpha(double[][] pp) {
        double[] suffstats = suffStats(pp);
        double[] pmean = guessMean(pp);
        double[] alpha = guessAlpha(pp, pmean);
        if (USE_NEWTON) {
            alphaNewton(pp.length, suffstats, alpha);
        } else {
            alphaFixedPoint(suffstats, alpha);
        }
        return alpha;
    }

    /**
     * Estimates the Dirichlet mean and precision separately.
     * <p>
     * It alternates between a precision update with the mean held fixed and a
     * mean update with the precision held fixed.
     *
     * @param pp observations, one probability vector of length K per row
     * @return a packed vector of length K. Entries 0..K-2 hold the first K-1
     *         mean components, and entry K-1 holds the precision. Use
     *         {@link #getMean}, {@link #getPrec} and {@link #getAlpha} to unpack it.
     */
    public static double[] estimateMeanPrec(double[][] pp) {
        double[] mean = guessMean(pp);
        double[] meansq = colMoments(pp, 2);
        double prec = guessPrecision(mean, meansq);
        double[] suffstats = suffStats(pp);
        for (int round = 0; round < MEAN_PREC_ROUNDS; round++) {
            prec = precFixedPoint(suffstats, pp.length, mean, prec);
            meanGenNewton(suffstats, mean, prec);
        }
        // The last mean component is implied by the sum-to-one constraint,
        // so its slot is reused for the precision.
        double[] packed = new double[mean.length];
        System.arraycopy(mean, 0, packed, 0, mean.length - 1);
        packed[mean.length - 1] = prec;
        return packed;
    }

    /**
     * Estimates the Dirichlet alpha from count data.
     * <p>
     * Each row is first normalised to proportions. Rows whose counts sum to
     * zero yield NaN proportions.
     *
     * @param nmk counts, {@code nmk[m][k]} = count of category k in row m
     * @return estimated alpha of length {@code nmk[0].length}
     */
    public static double[] estimateAlpha(int[][] nmk) {
        return estimateAlpha(countsToProportions(nmk));
    }

    /**
     * Estimates the Polya (Dirichlet-multinomial) alpha from count data.
     * <p>
     * It uses the leave-one-out fixed-point update:
     * {@code alpha_k <- alpha_k * sum_m n_mk/(n_mk - 1 + alpha_k)
     *                          / sum_m n_m/(n_m - 1 + sum(alpha))}.
     * Zero counts contribute nothing to either sum and are skipped. Skipping
     * avoids a 0/0 when an alpha component, or its sum, equals exactly 1.
     *
     * @param alpha starting value. It is updated in place with the estimate.
     * @param nmk   counts, {@code nmk[m][k]} = count of category k in row m
     * @return the index of the iteration at which the total absolute change
     *         fell below 1e-9. If the iteration limit is reached without
     *         converging, the limit (20000) is returned.
     */
    public static int estimateAlphaLoo(double[] alpha, int[][] nmk) {
        final double limdist = 1.0E-9;
        final int iter = 20000;
        int M = nmk.length;
        int K = alpha.length;
        int[] nm = new int[M];
        for (int m = 0; m < M; m++) {
            nm[m] = Vectors.sum(nmk[m]);
        }
        double[] alphanew = new double[K];
        for (int i = 0; i < iter; i++) {
            double sumalpha = Vectors.sum(alpha);
            double den = 0.0;
            for (int m = 0; m < M; m++) {
                if (nm[m] > 0) {
                    den += nm[m] / (nm[m] - 1 + sumalpha);
                }
            }
            double diffalpha = 0.0;
            for (int k = 0; k < K; k++) {
                double num = 0.0;
                for (int m = 0; m < M; m++) {
                    int n = nmk[m][k];
                    if (n > 0) {
                        num += n / (n - 1 + alpha[k]);
                    }
                }
                alphanew[k] = alpha[k] * num / den;
                diffalpha += Math.abs(alpha[k] - alphanew[k]);
            }
            // Write back into the caller's array. Merely rebinding a local
            // would lose the estimate, because only the iteration count is returned.
            System.arraycopy(alphanew, 0, alpha, 0, K);
            if (diffalpha < limdist) {
                return i;
            }
        }
        return iter;
    }

    /**
     * Estimates the Dirichlet mean and precision from count data.
     * <p>
     * Each row is first normalised to proportions. Rows whose counts sum to
     * zero yield NaN proportions.
     *
     * @param nn counts, {@code nn[m][k]} = count of category k in row m
     * @return a packed mean/precision vector, as returned by
     *         {@link #estimateMeanPrec(double[][])}
     */
    public static double[] estimateMeanPrec(int[][] nn) {
        return estimateMeanPrec(countsToProportions(nn));
    }

    /**
     * Extracts the precision from a packed mean/precision vector.
     *
     * @param meanPrec packed vector as returned by {@code estimateMeanPrec}
     * @return the precision, stored in the last entry
     */
    public static double getPrec(double[] meanPrec) {
        return meanPrec[meanPrec.length - 1];
    }

    /**
     * Extracts the full mean vector from a packed mean/precision vector.
     * <p>
     * The last component is reconstructed as 1 minus the sum of the others.
     *
     * @param meanPrec packed vector as returned by {@code estimateMeanPrec}
     * @return a new mean vector of the same length as {@code meanPrec}
     */
    public static double[] getMean(double[] meanPrec) {
        int last = meanPrec.length - 1;
        double[] mean = new double[meanPrec.length];
        double sum = 0.0;
        for (int k = 0; k < last; k++) {
            mean[k] = meanPrec[k];
            sum += meanPrec[k];
        }
        mean[last] = 1.0 - sum;
        return mean;
    }

    /**
     * Converts a packed mean/precision vector to alpha = precision * mean.
     *
     * @param meanPrec packed vector as returned by {@code estimateMeanPrec}
     * @return a new alpha vector
     */
    public static double[] getAlpha(double[] meanPrec) {
        double[] aa = getMean(meanPrec);
        Vectors.mult(aa, getPrec(meanPrec));
        return aa;
    }

    /**
     * Computes a moment-matching guess of the Dirichlet alpha.
     * <p>
     * The guess is {@code alpha_k = mean_k * s}. Here {@code s} is the
     * precision from {@link #guessPrecision}, i.e. the ratio
     * (E[p] - E[p^2]) / Var[p] averaged over the components.
     *
     * @param pp    observations, one probability vector of length K per row
     * @param pmean column means of {@code pp} (see {@link #guessMean})
     * @return alpha guess of length K
     */
    public static double[] guessAlpha(double[][] pp, double[] pmean) {
        int K = pp[0].length;
        double precision = guessPrecision(pmean, colMoments(pp, 2));
        double[] alpha = Vectors.copy(pmean);
        // guessPrecision already averages over K. Dividing by K again would
        // make alpha K times too small.
        for (int k = 0; k < K; k++) {
            alpha[k] = pmean[k] * precision;
        }
        return alpha;
    }

    /**
     * Computes the column means of the observations.
     *
     * @param pp observations, one probability vector per row
     * @return the first-order column moments
     */
    public static double[] guessMean(double[][] pp) {
        return colMoments(pp, 1);
    }

    /**
     * Iteratively updates the mean in place for a fixed precision.
     * <p>
     * Each iteration computes
     * {@code alpha_k = invdigamma(suff_k - sum_j mean_j (suff_j - digamma(prec mean_j)))}
     * and then renormalises alpha to obtain the new mean.
     *
     * @param suffstats mean log-observations per component (see {@link #suffStats})
     * @param mean      current mean. It is overwritten with the result.
     * @param prec      fixed precision
     */
    private static void meanGenNewton(double[] suffstats, double[] mean, double prec) {
        int K = mean.length;
        double[] alpha = new double[K];
        for (int i = 0; i < MEAN_ITERATIONS; i++) {
            // This sum does not depend on k. It is recomputed from scratch on
            // every iteration instead of being added onto the previous alpha.
            double common = 0.0;
            for (int j = 0; j < K; j++) {
                common += mean[j] * (suffstats[j] - Gamma.digamma(prec * mean[j]));
            }
            for (int k = 0; k < K; k++) {
                alpha[k] = Gamma.invdigamma(suffstats[k] - common);
            }
            double sumalpha = Vectors.sum(alpha);
            for (int k = 0; k < K; k++) {
                mean[k] = alpha[k] / sumalpha;
            }
        }
    }

    /**
     * Performs one Newton step on 1/precision for a fixed mean.
     * <p>
     * Definitions:
     * <ul>
     * <li>{@code L' = N (digamma(s) - sum_k m_k (digamma(s m_k) - suff_k))}</li>
     * <li>{@code L'' = N (trigamma(s) - sum_k m_k^2 trigamma(s m_k))}</li>
     * </ul>
     * The step is {@code 1/s_new = 1/s + L' / (s^2 L'')}.
     *
     * @param suffstats mean log-observations per component
     * @param N         number of observations
     * @param mean      fixed mean vector
     * @param prec      current precision
     * @return updated precision
     */
    private static double precFixedPoint(double[] suffstats, int N, double[] mean, double prec) {
        double dsum = 0.0;
        double ddsum = 0.0;
        for (int k = 0; k < mean.length; k++) {
            double mk = mean[k];
            dsum += mk * (Gamma.digamma(prec * mk) - suffstats[k]);
            ddsum += mk * mk * Gamma.trigamma(prec * mk);
        }
        double dloglik = N * (Gamma.digamma(prec) - dsum);
        double ddloglik = N * (Gamma.trigamma(prec) - ddsum);
        double precinv = 1.0 / prec + dloglik / (prec * prec * ddloglik);
        return 1.0 / precinv;
    }

    /**
     * Computes a moment-matching guess of the Dirichlet alpha from count data.
     * <p>
     * Each row is first normalised to proportions.
     *
     * @param nmk counts, {@code nmk[m][k]} = count of category k in row m
     * @return alpha guess of length {@code nmk[0].length}
     */
    public static double[] guessAlpha(int[][] nmk) {
        double[][] pmk = countsToProportions(nmk);
        return guessAlpha(pmk, guessMean(pmk));
    }

    /**
     * Computes a moment-matching guess of the alpha vector from counts.
     * <p>
     * The log-precision is estimated per component as
     * {@code log(mean_k (1 - mean_k) / var_k - 1)} over the K-1 free
     * components. These estimates are then averaged, which yields the
     * geometric mean of the precision estimates.
     *
     * @param nmk counts, {@code nmk[m][k]} = count of category k in row m
     * @param nm  row totals, {@code nm[m] = sum_k nmk[m][k]}
     * @return alpha guess = precision * mean, of length {@code nmk[0].length}
     * @throws IllegalArgumentException if there are fewer than two categories
     */
    public static double[] guessAlphaDirect(int[][] nmk, int[] nm) {
        int K = nmk[0].length;
        if (K < 2) {
            throw new IllegalArgumentException("need at least two categories, got " + K);
        }
        int M = nm.length;
        double[] pmean = new double[K];
        double[] pmk = new double[M];
        double logprec = 0.0;
        for (int k = 0; k < K; k++) {
            for (int m = 0; m < M; m++) {
                pmk[m] = (double) nmk[m][k] / nm[m];
                pmean[k] += pmk[m];
            }
            pmean[k] /= M;
            // Only K-1 components are free; the last is fixed by sum-to-one.
            if (k < K - 1) {
                double pvark = 0.0;
                for (int m = 0; m < M; m++) {
                    double diff = pmk[m] - pmean[k];
                    pvark += diff * diff;
                }
                pvark /= M;
                logprec += Math.log(pmean[k] * (1.0 - pmean[k]) / pvark - 1.0);
            }
        }
        // Floating-point division is required here. The integer expression
        // 1 / (K - 1) is 0 for K > 2.
        double prec = Math.exp(logprec / (K - 1));
        Vectors.mult(pmean, prec);
        return pmean;
    }

    /**
     * Estimates the Dirichlet precision by moment matching.
     * <p>
     * It averages {@code (E[p_k] - E[p_k^2]) / (E[p_k^2] - E[p_k]^2)} over all
     * components.
     *
     * @param pmean   first moments E[p_k]
     * @param pmeansq second moments E[p_k^2]
     * @return precision estimate
     */
    public static double guessPrecision(double[] pmean, double[] pmeansq) {
        int K = pmean.length;
        double precision = 0.0;
        for (int k = 0; k < K; k++) {
            precision += (pmean[k] - pmeansq[k]) / (pmeansq[k] - pmean[k] * pmean[k]);
        }
        return precision / K;
    }

    /**
     * Computes the raw column moments {@code E[x_k^order]} of a matrix.
     *
     * @param xx    data, one row per observation
     * @param order moment order. Values of 1 or less yield the plain mean.
     * @return column moments of length {@code xx[0].length}
     */
    private static double[] colMoments(double[][] xx, int order) {
        int K = xx[0].length;
        int N = xx.length;
        double[] moments = new double[K];
        for (double[] row : xx) {
            for (int k = 0; k < K; k++) {
                double x = row[k];
                // Multiply by the base each time. Squaring the running product
                // would compute x^(2^(order-1)) instead of x^order.
                double power = x;
                for (int d = 1; d < order; d++) {
                    power *= x;
                }
                moments[k] += power;
            }
        }
        for (int k = 0; k < K; k++) {
            moments[k] /= N;
        }
        return moments;
    }

    /**
     * Computes the Dirichlet sufficient statistics.
     * <p>
     * The statistic for component k is the mean over rows of
     * {@code log(p_k + 1e-6)}. The small offset keeps zero proportions finite.
     *
     * @param pp observations, one probability vector per row
     * @return sufficient statistics of length {@code pp[0].length}
     */
    private static double[] suffStats(double[][] pp) {
        int K = pp[0].length;
        int N = pp.length;
        final double eps = 1.0E-6;
        double[] suffstats = new double[K];
        for (double[] row : pp) {
            for (int k = 0; k < K; k++) {
                suffstats[k] += Math.log(row[k] + eps);
            }
        }
        for (int k = 0; k < K; k++) {
            suffstats[k] /= N;
        }
        return suffstats;
    }

    /**
     * Refines alpha in place with Newton's method on the Dirichlet log-likelihood.
     * <p>
     * The Hessian is diagonal plus rank one. Each Newton step therefore costs
     * O(K).
     * <p>
     * Iteration stops when the log-likelihood changes by less than 1e-6, or
     * after 1000 iterations.
     *
     * @param N         number of observations
     * @param suffstats mean log-observations per component
     * @param alpha     starting value. It is updated in place.
     */
    public static void alphaNewton(int N, double[] suffstats, double[] alpha) {
        int K = alpha.length;
        final int iterations = 1000;
        final double epsilon = 1.0E-6;
        // Start from -infinity so the first iteration can never pass the stop test.
        double loglikold = Double.NEGATIVE_INFINITY;
        double[] grad = new double[K];
        double[] qdiag = new double[K];
        for (int i = 0; i < iterations; i++) {
            // Recompute everything from the current alpha. None of these
            // values may carry over from the previous iteration.
            double alphasum = Vectors.sum(alpha);
            double digammaSum = Gamma.digamma(alphasum);
            double lgasum = 0.0;
            double asssum = 0.0;
            for (int k = 0; k < K; k++) {
                lgasum += Gamma.lgamma(alpha[k]);
                asssum += (alpha[k] - 1.0) * suffstats[k];
                grad[k] = N * (digammaSum - Gamma.digamma(alpha[k]) + suffstats[k]);
            }
            double loglik = N * (Gamma.lgamma(alphasum) - lgasum + asssum);
            if (Math.abs(loglikold - loglik) < epsilon) {
                break;
            }
            loglikold = loglik;
            // The Hessian is diag(q) + z * 1 1^T. Its inverse applied to the
            // gradient is (g_k - b) / q_k.
            double z = N * Gamma.trigamma(alphasum);
            double bnum = 0.0;
            double bden = 0.0;
            for (int k = 0; k < K; k++) {
                qdiag[k] = -N * Gamma.trigamma(alpha[k]);
                bnum += grad[k] / qdiag[k];
                bden += 1.0 / qdiag[k];
            }
            double b = bnum / (1.0 / z + bden);
            for (int k = 0; k < K; k++) {
                alpha[k] -= (grad[k] - b) / qdiag[k];
            }
        }
    }

    /**
     * Refines alpha in place with the fixed-point iteration.
     * <p>
     * Each iteration sets {@code alpha_k = invdigamma(digamma(sum(alpha_old)) + suff_k)}.
     *
     * @param suffstats mean log-observations per component
     * @param alpha     starting value. It is updated in place.
     * @return the index of the iteration at which the largest component change
     *         fell below 1e-4. If the limit is reached without converging, the
     *         limit (100) is returned.
     */
    public static int alphaFixedPoint(double[] suffstats, double[] alpha) {
        int K = alpha.length;
        final double maxdiff = 1.0E-4;
        final int maxiter = 100;
        for (int i = 0; i < maxiter; i++) {
            double alphadiff = 0.0;
            // Computed from the alpha of the previous iteration.
            double digammaSum = Gamma.digamma(Vectors.sum(alpha));
            for (int k = 0; k < K; k++) {
                double old = alpha[k];
                alpha[k] = Gamma.invdigamma(digammaSum + suffstats[k]);
                alphadiff = Math.max(alphadiff, Math.abs(alpha[k] - old));
            }
            if (alphadiff < maxdiff) {
                return i;
            }
        }
        return maxiter;
    }

    /**
     * Computes a moment-matching estimate of a scalar (symmetric) alpha from counts.
     * <p>
     * The precision is averaged over the categories. It is then divided by K,
     * since a uniform mean gives alpha = s / K.
     *
     * @param nmk counts, {@code nmk[m][k]} = count of category k in row m
     * @param nm  row totals
     * @return scalar alpha estimate
     */
    public static double estimateAlphaMomentMatch(int[][] nmk, int[] nm) {
        int K = nmk[0].length;
        int M = nmk.length;
        final double eps = 1.0E-6;
        double precision = 0.0;
        for (int k = 0; k < K; k++) {
            // The moments are per category, so reset them for every k.
            double pmeank = 0.0;
            double pmeansqk = 0.0;
            for (int m = 0; m < M; m++) {
                double pmk = (double) nmk[m][k] / nm[m];
                pmeank += pmk;
                pmeansqk += pmk * pmk;
            }
            pmeank /= M;
            pmeansqk /= M;
            precision += (pmeank - pmeansqk) / (pmeansqk - pmeank * pmeank + eps);
        }
        return precision / ((double) K * K);
    }

    /**
     * Computes a MAP estimate of a scalar (symmetric) alpha from all rows of the count matrix.
     *
     * @param nmk   counts, {@code nmk[m][k]} = count of category k in row m
     * @param nm    row totals
     * @param alpha starting value
     * @param a     prior hyperparameter (presumably the Gamma shape; confirm)
     * @param b     prior hyperparameter (presumably the Gamma rate; confirm)
     * @return estimated alpha
     */
    public static double estimateAlphaMap(int[][] nmk, int[] nm, double alpha, double a, double b) {
        return estimateAlphaMapSub(nmk, nm, allRows(nmk.length), alpha, a, b);
    }

    /**
     * Computes a MAP estimate of a scalar (symmetric) alpha from a subset of rows.
     * <p>
     * Each iteration applies
     * {@code alpha <- (a - 1 + alpha * S_mk) / (b + K * S_m)}, where:
     * <ul>
     * <li>{@code S_mk = sum_{m,k} [digamma(n_mk + alpha) - digamma(alpha)]}</li>
     * <li>{@code S_m = sum_m [digamma(n_m + K alpha) - digamma(K alpha)]}</li>
     * </ul>
     * Iteration stops when alpha changes by less than 1e-5, or after 200 iterations.
     *
     * @param nmk   counts, {@code nmk[m][k]} = count of category k in row m
     * @param nm    row totals
     * @param mrows indices of the rows to use
     * @param alpha starting value
     * @param a     prior hyperparameter (presumably the Gamma shape; confirm)
     * @param b     prior hyperparameter (presumably the Gamma rate; confirm)
     * @return estimated alpha
     */
    public static double estimateAlphaMapSub(int[][] nmk, int[] nm, int[] mrows, double alpha, double a, double b) {
        final int iter = 200;
        final double prec = 1.0E-5;
        int M = mrows.length;
        int K = nmk[0].length;
        double alpha0 = 0.0;
        for (int i = 0; i < iter; i++) {
            double summk = 0.0;
            double summ = 0.0;
            for (int row : mrows) {
                summ += Gamma.digamma(K * alpha + nm[row]);
                for (int k = 0; k < K; k++) {
                    summk += Gamma.digamma(alpha + nmk[row][k]);
                }
            }
            summ -= M * Gamma.digamma(K * alpha);
            // Use double arithmetic so that M * K cannot overflow int.
            summk -= (double) M * K * Gamma.digamma(alpha);
            alpha = (a - 1.0 + alpha * summk) / (b + K * summ);
            if (Math.abs(alpha - alpha0) < prec) {
                return alpha;
            }
            alpha0 = alpha;
        }
        return alpha;
    }

    /**
     * Computes MAP estimates of one alpha vector per group of rows.
     * <p>
     * Row {@code m} belongs to group {@code m2j[m]}. Groups without rows are
     * left unchanged. Otherwise {@code alphajk[j]} is replaced by a new array
     * holding the estimate.
     *
     * @param nmk     counts, {@code nmk[m][k]} = count of category k in row m
     * @param nm      row totals
     * @param m2j     group index of every row
     * @param alphajk starting values, one alpha vector per group. Its rows are
     *                replaced in place.
     * @param a       prior hyperparameter
     * @param b       prior hyperparameter
     * @return {@code alphajk}
     */
    public static double[][] estimateAlphaMapSub(int[][] nmk, int[] nm, int[] m2j, double[][] alphajk, double a, double b) {
        int M = nmk.length;
        for (int j = 0; j < alphajk.length; j++) {
            int[] rows = rowsInGroup(m2j, M, j);
            if (rows.length == 0) {
                // Without data, the update would only scale alpha by the prior
                // every iteration. Keep the starting value instead.
                continue;
            }
            alphajk[j] = alphaMapFixedPoint(nmk, nm, rows, alphajk[j], a, b);
        }
        return alphajk;
    }

    /**
     * Computes a MAP estimate of an alpha vector from all rows of the count matrix.
     * <p>
     * The caller's {@code alpha} array is not modified.
     *
     * @param nmk   counts, {@code nmk[m][k]} = count of category k in row m
     * @param nm    row totals
     * @param alpha starting value, one entry per category
     * @param a     prior hyperparameter
     * @param b     prior hyperparameter
     * @return a new array holding the estimate
     */
    public static double[] estimateAlphaMap(int[][] nmk, int[] nm, double[] alpha, double a, double b) {
        return alphaMapFixedPoint(nmk, nm, allRows(nmk.length), alpha, a, b);
    }

    /**
     * Runs the vector MAP fixed point on the given rows.
     * <p>
     * The caller's {@code alpha} is left untouched.
     * <p>
     * The update is {@code alpha_k <- alpha_k (a + S_k) / (b/K + S)}, where
     * {@code S} is the row-total digamma sum. Iteration stops when the squared
     * distance between iterates is below 1e-5, or after 200 iterations.
     */
    private static double[] alphaMapFixedPoint(int[][] nmk, int[] nm, int[] rows, double[] alpha, double a, double b) {
        final int iter = 200;
        final double prec = 1.0E-5;
        int K = alpha.length;
        int M = rows.length;
        double[] cur = Vectors.copy(alpha);
        double[] next = new double[K];
        for (int i = 0; i < iter; i++) {
            double sumalpha = Vectors.sum(cur);
            // The row-total term does not depend on k, so compute it once per iteration.
            double summ = 0.0;
            for (int row : rows) {
                summ += Gamma.digamma(nm[row] + sumalpha);
            }
            summ -= M * Gamma.digamma(sumalpha);
            for (int k = 0; k < K; k++) {
                double summk = 0.0;
                for (int row : rows) {
                    summk += Gamma.digamma(nmk[row][k] + cur[k]);
                }
                summk -= M * Gamma.digamma(cur[k]);
                next[k] = cur[k] * (a + summk) / (b / K + summ);
            }
            boolean converged = Vectors.sqdist(next, cur) < prec;
            double[] swap = cur;
            cur = next;
            next = swap;
            if (converged) {
                break;
            }
        }
        return cur;
    }

    /** Normalises every count row to proportions; rows summing to zero give NaN entries. */
    private static double[][] countsToProportions(int[][] nmk) {
        int K = nmk[0].length;
        double[][] pmk = new double[nmk.length][K];
        for (int m = 0; m < nmk.length; m++) {
            double total = Vectors.sum(nmk[m]);
            for (int k = 0; k < K; k++) {
                pmk[m][k] = nmk[m][k] / total;
            }
        }
        return pmk;
    }

    /** Returns the row indices 0..M-1. */
    private static int[] allRows(int M) {
        int[] rows = new int[M];
        for (int m = 0; m < M; m++) {
            rows[m] = m;
        }
        return rows;
    }

    /** Returns, in ascending order, the indices m < M with {@code m2j[m] == j}. */
    private static int[] rowsInGroup(int[] m2j, int M, int j) {
        int count = 0;
        for (int m = 0; m < M; m++) {
            if (m2j[m] == j) {
                count++;
            }
        }
        int[] rows = new int[count];
        int r = 0;
        for (int m = 0; m < M; m++) {
            if (m2j[m] == j) {
                rows[r++] = m;
            }
        }
        return rows;
    }

    /**
     * Draws {@code n} samples of a scalar alpha with the ARMS sampler.
     * <p>
     * Before sampling, it prints a log-density table over [0.1, 2.8].
     */
    private static double[] sampleAlphaArms(int[][] nmk, double a, double b, int n) throws Exception {
        GammaPolyaParams params = new GammaPolyaParams();
        params.nmk = nmk;
        params.a = a;
        params.b = b;
        GammaPolyaArms arms = new GammaPolyaArms();
        double[] xprev = new double[]{5.0};
        double[] xl = new double[]{1.0E-5};
        double[] xr = new double[]{10.0};
        double[] alpha = new double[n];
        arms.hist(params, 0.1, 2.8, 200);
        for (int i = 0; i < n; i++) {
            alpha[i] = arms.armsSimple(params, 25, xl, xr, true, xprev);
        }
        return alpha;
    }

    public static void main(String[] args) throws Exception {
        DirichletEstimation.testPolya();
    }

    /** Demonstrates the count-based estimators on synthetic Polya data and prints the results. */
    public static void testPolya() {
        RandomSamplers rs = new RandomSamplers();
        int M = 10000;
        int K = 5;
        // Row lengths are N0 + N * U(0,1); N = 0 gives rows of fixed length N0.
        int N = 0;
        int N0 = 100;
        double[] alpha = new double[]{0.35, 0.35, 0.05, 0.24, 0.31};
        System.out.println("original alpha");
        System.out.println(Vectors.print(alpha));
        int[][] nmk = new int[M][K];
        int[] nm = new int[M];
        for (int m = 0; m < M; m++) {
            nm[m] = (int) (N0 + N * Math.random());
            nmk[m] = rs.randMultFreqs(rs.randDir(alpha), nm[m]);
        }
        int[] nk = new int[K];
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                nk[k] += nmk[m][k];
            }
        }
        System.out.println("sample counts in categories");
        System.out.println(Vectors.print(nk));
        alpha = DirichletEstimation.guessAlpha(nmk);
        System.out.println("estimated alpha from counts (moments matching)");
        System.out.println(Vectors.print(alpha));
        alpha = DirichletEstimation.estimateAlpha(nmk);
        System.out.println("estimated alpha from counts (fixed point via Dirichlet Eq. 9)");
        System.out.println(Vectors.print(alpha));
        System.out.println("guess alpha via Polya moment match");
        alpha = DirichletEstimation.guessAlphaDirect(nmk, nm);
        System.out.println(Vectors.print(alpha));
        System.out.println("estimated scalar alpha from counts via precision (moment matching)");
        System.out.println(DirichletEstimation.estimateAlphaMomentMatch(nmk, nm));
        System.out.println("estimated scalar alpha from counts via MAP estimator");
        System.out.println(DirichletEstimation.estimateAlphaMap(nmk, nm, 0.1, 0.5, 0.5));
        System.out.println("estimated vector alpha from counts via MAP estimator");
        double[] astart = Vectors.ones(K, 0.1);
        System.out.println(Vectors.print(DirichletEstimation.estimateAlphaMap(nmk, nm, astart, 0.5, 0.5)));
    }

    /** Demonstrates the proportion-based estimators on synthetic Dirichlet data and prints the results. */
    public static void testDirichlet() {
        double[] alpha = new double[]{1.0, 1.9, 1.9};
        System.out.println("original alpha");
        System.out.println(Vectors.print(alpha));
        RandomSamplers rs = new RandomSamplers();
        double[][] pp = rs.randDir(alpha, 100000);
        double[] alphaguess = DirichletEstimation.estimateAlpha(pp);
        System.out.println("estimated alpha");
        System.out.println(Vectors.print(alphaguess));
        System.out.println("guessed mean");
        double[] mean = DirichletEstimation.guessMean(pp);
        double[] suffstats = DirichletEstimation.suffStats(pp);
        System.out.println(Vectors.print(mean));
        System.out.println("guessed mean (Newton)");
        DirichletEstimation.meanGenNewton(suffstats, mean, 2.5);
        System.out.println(Vectors.print(mean));
        System.out.println("estimated precision, mean and alpha");
        double[] mp = DirichletEstimation.estimateMeanPrec(pp);
        System.out.println(DirichletEstimation.getPrec(mp));
        System.out.println(Vectors.print(DirichletEstimation.getMean(mp)));
        System.out.println(Vectors.print(DirichletEstimation.getAlpha(mp)));
    }

    /** ARMS sampler whose target is the scalar-alpha log-density over the count rows. */
    static class GammaPolyaArms
    extends ArmSampler {
        GammaPolyaArms() {
        }

        /**
         * Sums {@code Gamma.ldelta(nmk[m], alpha) - Gamma.ldelta(K, alpha)} over all rows.
         */
        @Override
        public double logpdf(double alpha, Object params) {
            double logpaz = 0.0;
            GammaPolyaParams gpp = (GammaPolyaParams) params;
            for (int[] row : gpp.nmk) {
                logpaz += Gamma.ldelta(row, alpha);
                logpaz -= Gamma.ldelta(row.length, alpha);
            }
            return logpaz;
        }

        /**
         * Prints {@code n} lines of "a, logpdf(a)", tab-separated.
         * <p>
         * The points are evenly spaced, starting one step above {@code amin}.
         */
        public void hist(Object params, double amin, double amax, int n) {
            double adiff = (amax - amin) / (double) (n + 1);
            double a = amin;
            for (int i = 0; i < n; i++) {
                a += adiff;
                double pa = this.logpdf(a, params);
                System.out.println(String.valueOf(a) + "\t" + pa);
            }
        }
    }

    /** Parameter holder passed to {@link GammaPolyaArms}. */
    static class GammaPolyaParams {
        int[][] nmk;
        double a;
        double b;

        GammaPolyaParams() {
        }
    }
}

