/*
 * unary.cpp: Simple unary operations on continued fractions.
 */

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include "spigot.h"
#include "funcs.h"
#include "cr.h"
#include "error.h"

class Abs : public BinaryIntervalSource {
    Spigot *x_orig;
    BracketingGenerator bg_x;

  public:
    Abs(Spigot *x)
        : x_orig(x->clone()), bg_x(x)
    {
        dprint("hello Abs %p", x);
    }

    virtual ~Abs()
    {
        delete x_orig;
    }

    virtual Abs *clone()
    {
        return new Abs(x_orig->clone());
    }

    virtual void gen_bin_interval(bigint *ret_lo, bigint *ret_hi,
                                  unsigned *ret_bits)
    {
        bg_x.get_bracket_shift(ret_lo, ret_hi, ret_bits);
        dprint("got x bracket (%b,%b) / 2^%d", ret_lo, ret_hi, (int)*ret_bits);

        /*
         * Adjust the interval we got to ensure it's all positive.
         */
        if (*ret_hi < 0) {
            /*
             * The whole interval was negative, so just reflect it.
             */
            bigint tmp = -*ret_hi;
            *ret_hi = -*ret_lo;
            *ret_lo = tmp;
        } else if (*ret_lo < 0) {
            /*
             * The interval crosses zero, so replace it with an
             * interval with one end at zero and the other end at the
             * maximum extent.
             */
            bigint tmp = -*ret_lo;
            if (*ret_hi < tmp)
                *ret_hi = tmp;
            *ret_lo = 0;
        }
    }
};

class Prepend : public Source {
    /*
     * This class prepends one extra spigot matrix to the stream
     * provided by another spigot, which permits us to apply any
     * Mobius transformation.
     */
    Source *x;
    bigint a, b, c, d;
    bool x_force;
    int crState;

  public:
    Prepend(Spigot *ax, bigint aa, bigint ab, bigint ac, bigint ad)
        : a(aa), b(ab), c(ac), d(ad)
    {
        x = ax->toSource();
        crState = -1;
    }

    ~Prepend()
    {
        delete x;
    }

    virtual Prepend *clone() { return new Prepend(x->clone(), a, b, c, d); }

    virtual bool gen_interval(bigint *low, bigint *high)
    {
        x_force = x->gen_interval(low, high);
        return true; /* force the absorption of our prefix matrix */
    }

    virtual bool gen_matrix(bigint *matrix)
    {
        crBegin;

        matrix[0] = a;
        matrix[1] = b;
        matrix[2] = c;
        matrix[3] = d;
        crReturn(x_force);

        while (1) {
            crReturn(x->gen_matrix(matrix));
        }

        crEnd;
    }
};

Spigot *spigot_reciprocal(Spigot *a)
{
    bigint n, d;
    if (a->is_rational(&n, &d)) {
        if (n == 0) {
            throw spigot_error("reciprocal of zero");
        }
        return spigot_rational(d, n);
    } else {
        return new Prepend(a, 0, 1, 1, 0);
    }
}

Spigot *spigot_rational_mul(Spigot *a, const bigint &n, const bigint &d)
{
    bigint an, ad;
    if (a->is_rational(&an, &ad)) {
        return spigot_rational(an * n, ad * d);
    } else {
        return new Prepend(a, n, 0, 0, d);
    }
}

Spigot *spigot_mobius(Spigot *x,
                      const bigint &a, const bigint &b,
                      const bigint &c, const bigint &d)
{
    bigint an, ad;
    if (x->is_rational(&an, &ad)) {
        return spigot_rational(a*an + b*ad, c*an + d*ad);
    } else {
        return new Prepend(x, a, b, c, d);
    }
}

Spigot *spigot_frac(Spigot *a)
{
    /*
     * To get the fractional part of a number, we do the completely
     * obvious thing of finding its floor and subtracting that off.
     * This has an obvious exactness hazard, but not at any point of
     * continuity, so it can't be helped - we really do need to know
     * which side of the boundary we're on.
     */
    StaticGenerator sg(a->clone());
    bool constant;
    bigint intpart = sg.get_floor(&constant);
    if (constant)
        return spigot_integer(0);
    else
        return spigot_mobius(a, 1, -intpart, 0, 1);
}

static Spigot *spigot_ieee_rem_1(Spigot *a)
{
    /*
     * Take the IEEE 754-style remainder of a with 1. That is, find
     * the nearest integer n to a, and return a-n; if a is equidistant
     * between two integers, take n to be the even one.
     *
     * We do this by the obvious approach of finding the integer part
     * of a+1/2. As in spigot_frac above, exactness hazards in this
     * are unavoidable.
     */
    StaticGenerator sg(spigot_add(spigot_rational(1, 2), a->clone()));
    bool constant;
    bigint intpart = sg.get_floor(&constant);
    if (constant) {
        /*
         * a+1/2 is _exactly_ equal to intpart, so this is the
         * half-way case, and we want either a-intpart or
         * a-(intpart-1).
         *
         * If intpart is even, then we want a-intpart = a-(a+1/2) =
         * -1/2; otherwise we want a-(intpart-1) = a-(a+1/2-1) = +1/2.
         */
        if (intpart % 2U == 0)
            return spigot_rational(-1, 2);
        else
            return spigot_rational(1, 2);
    } else {
        return spigot_mobius(a, 1, -intpart, 0, 1);
    }
}

Spigot *spigot_mod(Spigot *a, Spigot *b)
{
    return spigot_mul(b, spigot_frac(spigot_div(a, b->clone())));
}

Spigot *spigot_rem(Spigot *a, Spigot *b)
{
    return spigot_mul(b, spigot_ieee_rem_1(spigot_div(a, b->clone())));
}

Spigot *spigot_neg(Spigot *a)
{
    return spigot_rational_mul(a, -1, 1);
}

Spigot *spigot_abs(Spigot *a)
{
    bigint n, d;
    if (a->is_rational(&n, &d))
        return spigot_rational(bigint_abs(n), bigint_abs(d));

    return new Abs(a);
}

Spigot *spigot_floor(Spigot *a)
{
    CfracGenerator cfg(a);
    bigint i;
    cfg.get_term(&i);
    return spigot_integer(i);
}

Spigot *spigot_ceil(Spigot *a)
{
    CfracGenerator cfg(spigot_neg(a));
    bigint i;
    cfg.get_term(&i);
    return spigot_integer(-i);
}
