/*
 *  Mathlib : A C Library of Special Functions
 *  Copyright (C) 2000-2020 The R Core Team
 *  Copyright (C) 2005-2020 The R Foundation
 *  Copyright (C) 1998 Ross Ihaka
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, a copy is available at
 *  https://www.R-project.org/Licenses/
 *
 *  SYNOPSIS
 *
 *    #include <Rmath.h>
 *    double rhyper(double NR, double NB, double n);
 *
 *  DESCRIPTION
 *
 *    Random variates from the hypergeometric distribution.
 *    Returns the number of white balls drawn when kk balls
 *    are drawn at random from an urn containing nn1 white
 *    and nn2 black balls.
 *
 *  REFERENCE
 *
 *    V. Kachitvichyanukul and B. Schmeiser (1985).
 *    ``Computer generation of hypergeometric random variates,''
 *    Journal of Statistical Computation and Simulation 22, 127-145.
 *
 *    The original algorithm had a bug -- R bug report PR#7314 --
 *    giving numbers slightly too small in case III h2pe
 *    where (m < 100 || ix <= 50) , see below.
 */

#include "nmath.h"
#include "dpq.h"
#include <limits.h>

// afc(i) := ln(i!), the natural logarithm of the factorial of i
//	   = {R:} lgamma(i + 1) = {C:} lgammafn(i + 1)
//
// Arguments 0..7 are answered from a table; from 8 on, Stirling's series
// with the two correction terms 1/(12 i) - 1/(360 i^3) is evaluated.
// A negative argument is a caller error: it raises a MATHLIB_WARNING and
// yields -1.
static double afc(int i)
{
    /* ln(j!) for j = 0, 1, ..., 7 */
    const static double ln_fact[8] =
    {
	0.0,				    /* ln(0!) = ln(1)	*/
	0.0,				    /* ln(1!) = ln(1)	*/
	0.69314718055994530941723212145817, /* ln(2!) = ln(2)	*/
	1.79175946922805500081247735838070, /* ln(3!) = ln(6)	*/
	3.17805383034794561964694160129705, /* ln(4!) = ln(24)	*/
	4.78749174278204599424770093452324, /* ln(5!)		*/
	6.57925121201010099506017829290394, /* ln(6!)		*/
	8.52516136106541430016553103634712  /* ln(7!)		*/
	/* The next entry would be ln(8!) = 10.60460290274525022841722740072165;
	   the series below gives 10.6046028788027 instead, i.e. a
	   relative error of 2.26 10^{-9}.

	  FIXME: Use constants and if(n > ..) decisions from ./stirlerr.c
	  -----  will be even *faster* for n > 500 (or so)
	*/
    };

    if (i >= 8) { // Stirling's approximation
	double x = i, xsq = x*x;
	// Keep this one expression with this operand order: the result
	// feeds accept/reject decisions and must not change in any bit.
	return (x + 0.5) * log(x) - x + M_LN_SQRT_2PI +
	    (0.0833333333333333 - 0.00277777777777778 / xsq) / x;
    }
    if (i >= 0) // table lookup
	return ln_fact[i];

    // i < 0 :
    MATHLIB_WARNING(("rhyper.c: afc(i), i=%d < 0 -- SHOULD NOT HAPPEN!\n"), i);
    return -1; // unreached
}

// ln(x!) for an integer-valued x >= 0 that may be too large for an 'int'.
// Whenever x fits, this is exactly afc((int) x), so every parameter set
// whose sums fit in an 'int' keeps producing the same variates as before;
// only beyond INT_MAX -- where an int argument would already have
// overflowed -- lgammafn(x + 1) is used instead.
static double afc_dbl(double x)
{
    if (x <= INT_MAX)
	return afc((int) x);
    return lgammafn(x + 1.);
}

//     rhyper(NR, NB, n) -- NR 'red', NB 'blue', n drawn, how many are 'red'
//
// Arguments (each is rounded with R_forceint() before use):
//   nn1in  number of 'red'  balls in the urn
//   nn2in  number of 'blue' balls in the urn
//   kkin   number of balls drawn
//
// Value: the number of 'red' balls among the kkin drawn, as a double.
//   Leaves through ML_WARN_return_NAN when an argument is not finite or
//   negative, when kkin > nn1in + nn2in, or when branch III has not
//   accepted a candidate after 10000 tries.
//
// Internally the problem is reduced to n1 <= n2 and k <= N/2; the draw 'ix'
// for the reduced problem is mapped back at L_finis.  Three cases:
//   I    degenerate distribution (minjx == maxjx)
//   II   inversion (HIN) when  m - minjx < 10
//   III  H2PE acceptance/rejection otherwise
//
// All setup quantities are cached in function-level statics and only
// recomputed when (nn1, nn2) or kk differ from the previous call.
double rhyper(double nn1in, double nn2in, double kkin)
{
    /* extern double afc(int); */

    int nn1, nn2, kk;
    int ix; // return value (coerced to double at the very end)
    Rboolean setup1, setup2;

    /* These should become 'thread_local globals' : */
    static int ks = -1, n1s = -1, n2s = -1;
    static int m, minjx, maxjx;
    static int k, n1, n2; // <- not allowing larger integer par
    static double N;

    // II :
    static double w;
    // III:
    static double a, d, s, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;

    /* check parameter validity */

    if(!R_FINITE(nn1in) || !R_FINITE(nn2in) || !R_FINITE(kkin))
	ML_WARN_return_NAN;

    nn1in = R_forceint(nn1in);
    nn2in = R_forceint(nn2in);
    kkin  = R_forceint(kkin);

    if (nn1in < 0 || nn2in < 0 || kkin < 0 || kkin > nn1in + nn2in)
	ML_WARN_return_NAN;
    if (nn1in >= INT_MAX || nn2in >= INT_MAX || kkin >= INT_MAX) {
	/* large n -- evade integer overflow (and inappropriate algorithms)
	   -------- */
        // FIXME: Much faster to give rbinom() approx when appropriate; -> see Kuensch(1989)
	// Johnson, Kotz,.. p.258 (top) mention the *four* different binomial approximations
	if(kkin == 1.) { // Bernoulli
	    return rbinom(kkin, nn1in / (nn1in + nn2in));
	}
	// Slow, but safe: return  F^{-1}(U)  where F(.) = phyper(.) and  U ~ U[0,1]
	return qhyper(unif_rand(), nn1in, nn2in, kkin,
		      /*lower_tail =*/ FALSE, /*log_p = */ FALSE);
	// lower_tail=FALSE: a thinko, is still "correct" as equiv. to  U <--> 1-U
    }
    // From here on each of nn1in, nn2in, kkin is < INT_MAX -- but their
    // sum nn1in + nn2in need not be.
    nn1 = (int)nn1in;
    nn2 = (int)nn2in;
    kk  = (int)kkin;

    /* if new parameter values, initialize */
    if (nn1 != n1s || nn2 != n2s) { // n1 | n2 is changed: setup all
	setup1 = TRUE;	setup2 = TRUE;
    } else if (kk != ks) { // n1 & n2 are unchanged: setup 'k' only
	setup1 = FALSE;	setup2 = TRUE;
    } else { // all three unchanged ==> no setup
	setup1 = FALSE;	setup2 = FALSE;
    }
    if (setup1) { // n1 & n2
	n1s = nn1; n2s = nn2; // save
	N = nn1 + (double)nn2; // avoid int overflow
	if (nn1 <= nn2) {
	    n1 = nn1; n2 = nn2;
	} else { // nn2 < nn1
	    n1 = nn2; n2 = nn1;
	}
	// now have n1 <= n2
    }
    if (setup2) { // k
	ks = kk; // save
	if ((double)kk + kk >= N) { // this could overflow
	    k = (int)(N - kk);
	} else {
	    k = kk;
	}
    }
    if (setup1 || setup2) {
	m = (int) ((k + 1.) * (n1 + 1.) / (N + 2.)); // m := floor(adjusted mean E[.])
	minjx = imax2(0, k - n2);
	maxjx = imin2(n1, k);
#ifdef DEBUG_rhyper
	REprintf("rhyper(n1=%d, n2=%d, k=%d), setup: floor(a.mean)=: m = %d, [min,maxjx]= [%d,%d]\n",
		 nn1, nn2, kk, m, minjx, maxjx);
#endif
    }
    /* generate random variate --- Three basic cases */

    if (minjx == maxjx) { /* I: degenerate distribution ---------------- */
#ifdef DEBUG_rhyper
	REprintf("rhyper(), branch I (degenerate): ix := maxjx = %d\n", maxjx);
#endif
	ix = maxjx;
	goto L_finis; // return appropriate variate

    } else if (m - minjx < 10) { // II: (Scaled) algorithm HIN (inverse transformation) ----
	const static double scale = 1e25; // scaling factor against (early) underflow
	const static double con = 57.5646273248511421;
					  // 25*log(10) = log(scale) { <==> exp(con) == scale }
	if (setup1 || setup2) {
	    double lw; // log(w);  w = exp(lw) * scale = exp(lw + log(scale)) = exp(lw + con)
	    // N == n1 + n2 exactly.  That sum (and N - k) may exceed INT_MAX
	    // even though n1 and n2 each fit, so it is taken from the double
	    // N and never formed in int arithmetic (signed overflow would be
	    // undefined behaviour and, in practice, a negative afc() argument).
	    if (k < n2) {
		lw = afc(n2) + afc_dbl(N - k) - afc(n2 - k) - afc_dbl(N);
	    } else {
		lw = afc(n1) + afc(     k     ) - afc(k - n2) - afc_dbl(N);
	    }
	    w = exp(lw + con);
	}
	double p, u;
#ifdef DEBUG_rhyper
	REprintf("rhyper(), branch II; w = %g > 0\n", w);
#endif
      L10:
	// start (or restart) the inversion at the lower end of the support
	p = w;
	ix = minjx;
	u = unif_rand() * scale;
#ifdef DEBUG_rhyper
	REprintf("  _new_ u = %g\n", u);
#endif
	while (u > p) {
	    u -= p;
	    p *= ((double) n1 - ix) * (k - ix);
	    ix++;
	    p = p / ix / (n2 - k + ix);
#ifdef DEBUG_rhyper
	    REprintf("       ix=%3d, u=%11g, p=%20.14g (u-p=%g)\n", ix, u, p, u-p);
#endif
	    if (ix > maxjx) // walked past the support: draw a fresh u
		goto L10;
	    // FIXME  if(p == 0.)  we also "have lost"  => goto L10
	}

    } else { /* III : H2PE Algorithm --------------------------------------- */

	double u,v;

	if (setup1 || setup2) {
	    s = sqrt((N - k) * k * n1 * n2 / (N - 1) / N / N);

	    /* remark: d is defined in reference without int. */
	    /* the truncation centers the cell boundaries at 0.5 */

	    d = (int) (1.5 * s) + .5;
	    xl = m - d + .5;
	    xr = m + d + .5;
	    a = afc(m) + afc(n1 - m) + afc(k - m) + afc(n2 - k + m);
	    kl = exp(a - afc((int) (xl)) - afc((int) (n1 - xl))
		     - afc((int) (k - xl))
		     - afc((int) (n2 - k + xl)));
	    kr = exp(a - afc((int) (xr - 1))
		     - afc((int) (n1 - xr + 1))
		     - afc((int) (k - xr + 1))
		     - afc((int) (n2 - k + xr - 1)));
	    lamdl = -log(xl * (n2 - k + xl) / (n1 - xl + 1) / (k - xl + 1));
	    lamdr = -log((n1 - xr + 1) * (k - xr + 1) / xr / (n2 - k + xr));
	    p1 = d + d;
	    p2 = p1 + kl / lamdl;
	    p3 = p2 + kr / lamdr;
	}
#ifdef DEBUG_rhyper
	REprintf("rhyper(), branch III {accept/reject}: (xl,xr)= (%g,%g); (lamdl,lamdr)= (%g,%g)\n",
		 xl, xr, lamdl,lamdr);
	REprintf("-------- p123= c(%g,%g,%g)\n", p1,p2, p3);
#endif
	int n_uv = 0; // number of (u,v) pairs drawn in this call
      L30:
	u = unif_rand() * p3;
	v = unif_rand();
	n_uv++;
	if(n_uv >= 10000) {
	    REprintf("rhyper(*, n1=%d, n2=%d, k=%d): branch III: giving up after %d rejections\n",
		     nn1, nn2, kk, n_uv);
	    ML_WARN_return_NAN;
        }
#ifdef DEBUG_rhyper
	REprintf(" ... L30 [%d]: new (u=%g, v ~ U[0,1]=%g): ", n_uv, u,v);
#endif

	if (u < p1) {		/* rectangular region */
	    ix = (int) (xl + u);
	} else if (u <= p2) {	/* left tail */
	    ix = (int) (xl + log(v) / lamdl);
	    if (ix < minjx)
		goto L30;
	    v = v * (u - p1) * lamdl;
	} else {		/* right tail */
	    ix = (int) (xr - log(v) / lamdr);
	    if (ix > maxjx)
		goto L30;
	    v = v * (u - p2) * lamdr;
	}

	/* acceptance/rejection test */
	Rboolean reject = TRUE;

	if (m < 100 || ix <= 50) {
	    /* explicit evaluation */
	    /* The original algorithm (and TOMS 668) have
		   f = f * i * (n2 - k + i) / (n1 - i) / (k - i);
	       in the (m > ix) case, but the definition of the
	       recurrence relation on p134 shows that the +1 is
	       needed. */
	    int i;
	    double f = 1.0;
	    if (m < ix) {
		for (i = m + 1; i <= ix; i++)
		    f = f * (n1 - i + 1) * (k - i + 1) / (n2 - k + i) / i;
	    } else if (m > ix) {
		for (i = ix + 1; i <= m; i++)
		    f = f * i * (n2 - k + i) / (n1 - i + 1) / (k - i + 1);
	    }
	    if (v <= f) {
		reject = FALSE;
	    }
	} else {

	    const static double deltal = 0.0078;
	    const static double deltau = 0.0034;

	    double e, g, r, t, y;
	    double de, dg, dr, ds, dt, gl, gu, nk, nm, ub;
	    double xk, xm, xn, y1, ym, yn, yk, alv;

#ifdef DEBUG_rhyper
	    REprintf(" ... accept/reject 'large' case v=%g\n", v);
#endif
	    /* squeeze using upper and lower bounds */
	    y = ix;
	    y1 = y + 1.0;
	    ym = y - m;
	    yn = n1 - y + 1.0;
	    yk = k - y + 1.0;
	    nk = n2 - k + y1;
	    r = -ym / y1;
	    // NB: this reuses (overwrites) the static 's'.  Harmless: the
	    // cached value is only ever read inside the setup block above,
	    // right after it has been recomputed there.
	    s = ym / yn;
	    t = ym / yk;
	    e = -ym / nk;
	    g = yn * yk / (y1 * nk) - 1.0;
	    dg = 1.0;
	    if (g < 0.0)
		dg = 1.0 + g;
	    gu = g * (1.0 + g * (-0.5 + g / 3.0));
	    gl = gu - .25 * (g * g * g * g) / dg;
	    xm = m + 0.5;
	    xn = n1 - m + 0.5;
	    xk = k - m + 0.5;
	    nm = n2 - k + xm;
	    ub = y * gu - m * gl + deltau
		+ xm * r * (1. + r * (-0.5 + r / 3.0))
		+ xn * s * (1. + s * (-0.5 + s / 3.0))
		+ xk * t * (1. + t * (-0.5 + t / 3.0))
		+ nm * e * (1. + e * (-0.5 + e / 3.0));
	    /* test against upper bound */
	    alv = log(v);
	    if (alv > ub) {
		reject = TRUE;
	    } else {
				/* test against lower bound */
		dr = xm * (r * r * r * r);
		if (r < 0.0)
		    dr /= (1.0 + r);
		ds = xn * (s * s * s * s);
		if (s < 0.0)
		    ds /= (1.0 + s);
		dt = xk * (t * t * t * t);
		if (t < 0.0)
		    dt /= (1.0 + t);
		de = nm * (e * e * e * e);
		if (e < 0.0)
		    de /= (1.0 + e);
		if (alv < ub - 0.25 * (dr + ds + dt + de)
		    + (y + m) * (gl - gu) - deltal) {
		    reject = FALSE;
		}
		else {
		    /* * Stirling's formula to machine accuracy
		     */
		    if (alv <= (a - afc(ix) - afc(n1 - ix)
				- afc(k - ix) - afc(n2 - k + ix))) {
			reject = FALSE;
		    } else {
			reject = TRUE;
		    }
		}
	    }
	} // else
	if (reject)
	    goto L30;

    } // end{branch III}


L_finis:  /* return appropriate variate */

    // Undo the reductions made at setup: k was replaced by N - k when
    // 2*kk >= N, and the two colours were swapped when nn1 > nn2.
#ifdef DEBUG_rhyper
    REprintf(" L_finis: ix = %d, then", ix);
#endif
    if ((double)kk + kk >= N) {
	if (nn1 > nn2) {
	    ix = kk - nn2 + ix;
	} else {
	    ix = nn1 - ix;
	}
    } else if (nn1 > nn2) {
	ix = kk - ix;
    }
#ifdef DEBUG_rhyper
    REprintf(" %d\n", ix);
#endif
    return ix;
}
