/*
    Copyright (C) 2024 Fredrik Johansson

    This file is part of FLINT.

    FLINT is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/

#include <stdint.h>
#include "mpn_mod.h"

/* Tuning tables generated by p-poly_mullow. */

/* The first entry is the cutoff for classical -> Karatsuba,
   the second for Karatsuba -> KS/fft_small. We make the simplifying
   assumption that fft_small always beats KS when it can be used;
   this is not always the case, but close enough. */

/* Cutoffs read by _mpn_mod_poly_mullow for a full-length product
   (len == len1 + len2 - 1) when the call is not handled as a squaring
   and the longer operand is shorter than twice the shorter one.
   Each row is {classical -> Karatsuba, Karatsuba -> KS/fft_small},
   compared against the shorter operand length. Rows are indexed by
   (bits - FLINT_BITS - 1) / 16, where bits is the modulus bit count;
   the trailing comment on each row is its bit-size label. */
static const uint8_t mul_cutoffs[][2] = {
  {37, 91},   /* bits = 80 */
  {32, 93},   /* bits = 96 */
  {31, 93},   /* bits = 112 */
  {80, 90},   /* bits = 128 */
  {24, 77},   /* bits = 144 */
  {15, 73},   /* bits = 160 */
  {16, 113},   /* bits = 176 */
  {20, 82},   /* bits = 192 */
  {14, 82},   /* bits = 208 */
  {14, 109},   /* bits = 224 */
  {14, 106},   /* bits = 240 */
  {15, 95},   /* bits = 256 */
  {10, 90},   /* bits = 272 */
  {10, 90},   /* bits = 288 */
  {13, 90},   /* bits = 304 */
  {14, 81},   /* bits = 320 */
  {11, 82},   /* bits = 336 */
  {10, 81},   /* bits = 352 */
  {10, 97},   /* bits = 368 */
  {12, 72},   /* bits = 384 */
  {10, 69},   /* bits = 400 */
  {9, 70},   /* bits = 416 */
  {10, 83},   /* bits = 432 */
  {10, 73},   /* bits = 448 */
  {8, 69},   /* bits = 464 */
  {8, 71},   /* bits = 480 */
  {9, 73},   /* bits = 496 */
  {10, 54},   /* bits = 512 */
  {8, 53},   /* bits = 528 */
  {8, 65},   /* bits = 544 */
  {8, 65},   /* bits = 560 */
  {8, 51},   /* bits = 576 */
  {6, 57},   /* bits = 592 */
  {6, 67},   /* bits = 608 */
  {6, 67},   /* bits = 624 */
  {8, 45},   /* bits = 640 */
  {6, 50},   /* bits = 656 */
  {6, 53},   /* bits = 672 */
  {6, 53},   /* bits = 688 */
  {8, 40},   /* bits = 704 */
  {6, 39},   /* bits = 720 */
  {6, 49},   /* bits = 736 */
  {6, 65},   /* bits = 752 */
  {6, 36},   /* bits = 768 */
  {6, 36},   /* bits = 784 */
  {6, 41},   /* bits = 800 */
  {6, 41},   /* bits = 816 */
  {6, 34},   /* bits = 832 */
  {6, 34},   /* bits = 848 */
  {4, 38},   /* bits = 864 */
  {6, 37},   /* bits = 880 */
  {6, 36},   /* bits = 896 */
  {4, 36},   /* bits = 912 */
  {4, 37},   /* bits = 928 */
  {4, 37},   /* bits = 944 */
  {6, 34},   /* bits = 960 */
  {4, 34},   /* bits = 976 */
  {4, 34},   /* bits = 992 */
  {4, 35},   /* bits = 1008 */
  {6, 25},   /* bits = 1024 */
};

/* Cutoffs read by _mpn_mod_poly_mullow when the call is not handled as
   a squaring and the longer operand is at least twice as long as the
   shorter one (this table is chosen whether or not the product is
   truncated). Each row is {classical -> Karatsuba,
   Karatsuba -> KS/fft_small}, compared against the shorter operand
   length. Rows are indexed by (bits - FLINT_BITS - 1) / 16, where bits
   is the modulus bit count; the trailing comment on each row is its
   bit-size label. */
static const uint8_t mul_unbalanced_cutoffs[][2] = {
  {54, 55},   /* bits = 80 */
  {44, 46},   /* bits = 96 */
  {42, 53},   /* bits = 112 */
  {63, 63},   /* bits = 128 */
  {34, 45},   /* bits = 144 */
  {18, 39},   /* bits = 160 */
  {12, 50},   /* bits = 176 */
  {27, 42},   /* bits = 192 */
  {13, 29},   /* bits = 208 */
  {12, 37},   /* bits = 224 */
  {12, 40},   /* bits = 240 */
  {34, 34},   /* bits = 256 */
  {12, 33},   /* bits = 272 */
  {11, 43},   /* bits = 288 */
  {12, 46},   /* bits = 304 */
  {39, 39},   /* bits = 320 */
  {12, 40},   /* bits = 336 */
  {11, 42},   /* bits = 352 */
  {11, 44},   /* bits = 368 */
  {36, 36},   /* bits = 384 */
  {8, 35},   /* bits = 400 */
  {8, 40},   /* bits = 416 */
  {9, 38},   /* bits = 432 */
  {21, 37},   /* bits = 448 */
  {8, 35},   /* bits = 464 */
  {8, 34},   /* bits = 480 */
  {8, 33},   /* bits = 496 */
  {23, 32},   /* bits = 512 */
  {7, 31},   /* bits = 528 */
  {7, 31},   /* bits = 544 */
  {7, 30},   /* bits = 560 */
  {22, 28},   /* bits = 576 */
  {6, 28},   /* bits = 592 */
  {7, 27},   /* bits = 608 */
  {6, 27},   /* bits = 624 */
  {15, 26},   /* bits = 640 */
  {7, 25},   /* bits = 656 */
  {6, 25},   /* bits = 672 */
  {6, 24},   /* bits = 688 */
  {11, 24},   /* bits = 704 */
  {5, 23},   /* bits = 720 */
  {5, 23},   /* bits = 736 */
  {5, 22},   /* bits = 752 */
  {11, 22},   /* bits = 768 */
  {5, 21},   /* bits = 784 */
  {4, 21},   /* bits = 800 */
  {5, 21},   /* bits = 816 */
  {7, 20},   /* bits = 832 */
  {6, 20},   /* bits = 848 */
  {4, 19},   /* bits = 864 */
  {5, 19},   /* bits = 880 */
  {6, 19},   /* bits = 896 */
  {5, 18},   /* bits = 912 */
  {4, 18},   /* bits = 928 */
  {4, 18},   /* bits = 944 */
  {6, 18},   /* bits = 960 */
  {5, 17},   /* bits = 976 */
  {5, 17},   /* bits = 992 */
  {4, 17},   /* bits = 1008 */
  {16, 16},   /* bits = 1024 */
};

/* Cutoffs read by _mpn_mod_poly_mullow for a full-length squaring:
   poly1 == poly2 with equal (clamped) lengths and
   len == len1 + len2 - 1. Each row is {classical -> Karatsuba,
   Karatsuba -> KS/fft_small}, compared against the operand length.
   Rows are indexed by (bits - FLINT_BITS - 1) / 16, where bits is the
   modulus bit count; the trailing comment on each row is its bit-size
   label. */
static const uint8_t sqr_cutoffs[][2] = {
  {34, 96},   /* bits = 80 */
  {34, 93},   /* bits = 96 */
  {43, 93},   /* bits = 112 */
  {98, 98},   /* bits = 128 */
  {40, 78},   /* bits = 144 */
  {24, 76},   /* bits = 160 */
  {24, 121},   /* bits = 176 */
  {34, 90},   /* bits = 192 */
  {20, 100},   /* bits = 208 */
  {18, 119},   /* bits = 224 */
  {22, 117},   /* bits = 240 */
  {32, 100},   /* bits = 256 */
  {18, 101},   /* bits = 272 */
  {16, 99},   /* bits = 288 */
  {20, 117},   /* bits = 304 */
  {20, 91},   /* bits = 320 */
  {14, 99},   /* bits = 336 */
  {14, 103},   /* bits = 352 */
  {18, 133},   /* bits = 368 */
  {20, 106},   /* bits = 384 */
  {18, 88},   /* bits = 400 */
  {14, 118},   /* bits = 416 */
  {18, 113},   /* bits = 432 */
  {18, 109},   /* bits = 448 */
  {16, 90},   /* bits = 464 */
  {15, 102},   /* bits = 480 */
  {16, 99},   /* bits = 496 */
  {16, 69},   /* bits = 512 */
  {12, 69},   /* bits = 528 */
  {8, 90},   /* bits = 544 */
  {10, 88},   /* bits = 560 */
  {10, 83},   /* bits = 576 */
  {10, 83},   /* bits = 592 */
  {12, 70},   /* bits = 608 */
  {11, 79},   /* bits = 624 */
  {16, 57},   /* bits = 640 */
  {12, 69},   /* bits = 656 */
  {11, 73},   /* bits = 672 */
  {8, 72},   /* bits = 688 */
  {12, 68},   /* bits = 704 */
  {8, 68},   /* bits = 720 */
  {8, 67},   /* bits = 736 */
  {7, 66},   /* bits = 752 */
  {11, 64},   /* bits = 768 */
  {8, 63},   /* bits = 784 */
  {6, 63},   /* bits = 800 */
  {6, 65},   /* bits = 816 */
  {8, 59},   /* bits = 832 */
  {6, 58},   /* bits = 848 */
  {6, 57},   /* bits = 864 */
  {7, 57},   /* bits = 880 */
  {8, 55},   /* bits = 896 */
  {6, 54},   /* bits = 912 */
  {6, 53},   /* bits = 928 */
  {6, 53},   /* bits = 944 */
  {8, 52},   /* bits = 960 */
  {6, 51},   /* bits = 976 */
  {6, 50},   /* bits = 992 */
  {6, 50},   /* bits = 1008 */
  {8, 34},   /* bits = 1024 */
};

/* Cutoffs read by _mpn_mod_poly_mullow for a truncated product
   (len != len1 + len2 - 1) when the call is not handled as a squaring
   and the longer operand is shorter than twice the shorter one.
   Each row is {classical -> Karatsuba, Karatsuba -> KS/fft_small},
   compared against the shorter operand length. Rows are indexed by
   (bits - FLINT_BITS - 1) / 16, where bits is the modulus bit count;
   the trailing comment on each row is its bit-size label. */
static const uint8_t mullow_cutoffs[][2] = {
  {106, 106},   /* bits = 80 */
  {104, 104},   /* bits = 96 */
  {104, 104},   /* bits = 112 */
  {126, 126},   /* bits = 128 */
  {84, 84},   /* bits = 144 */
  {64, 69},   /* bits = 160 */
  {64, 88},   /* bits = 176 */
  {84, 84},   /* bits = 192 */
  {56, 86},   /* bits = 208 */
  {56, 109},   /* bits = 224 */
  {56, 113},   /* bits = 240 */
  {96, 96},   /* bits = 256 */
  {60, 90},   /* bits = 272 */
  {60, 95},   /* bits = 288 */
  {59, 90},   /* bits = 304 */
  {86, 86},   /* bits = 320 */
  {48, 82},   /* bits = 336 */
  {61, 82},   /* bits = 352 */
  {60, 97},   /* bits = 368 */
  {75, 75},   /* bits = 384 */
  {44, 69},   /* bits = 400 */
  {46, 70},   /* bits = 416 */
  {47, 83},   /* bits = 432 */
  {64, 72},   /* bits = 448 */
  {43, 69},   /* bits = 464 */
  {32, 71},   /* bits = 480 */
  {40, 73},   /* bits = 496 */
  {63, 65},   /* bits = 512 */
  {40, 54},   /* bits = 528 */
  {32, 65},   /* bits = 544 */
  {32, 65},   /* bits = 560 */
  {47, 58},   /* bits = 576 */
  {30, 57},   /* bits = 592 */
  {28, 67},   /* bits = 608 */
  {28, 67},   /* bits = 624 */
  {32, 52},   /* bits = 640 */
  {24, 52},   /* bits = 656 */
  {24, 54},   /* bits = 672 */
  {24, 67},   /* bits = 688 */
  {32, 49},   /* bits = 704 */
  {27, 41},   /* bits = 720 */
  {26, 49},   /* bits = 736 */
  {24, 66},   /* bits = 752 */
  {31, 42},   /* bits = 768 */
  {24, 37},   /* bits = 784 */
  {24, 43},   /* bits = 800 */
  {24, 51},   /* bits = 816 */
  {31, 39},   /* bits = 832 */
  {16, 38},   /* bits = 848 */
  {16, 38},   /* bits = 864 */
  {16, 39},   /* bits = 880 */
  {29, 36},   /* bits = 896 */
  {24, 36},   /* bits = 912 */
  {16, 37},   /* bits = 928 */
  {16, 37},   /* bits = 944 */
  {29, 35},   /* bits = 960 */
  {16, 35},   /* bits = 976 */
  {16, 35},   /* bits = 992 */
  {16, 35},   /* bits = 1008 */
  {32, 32},   /* bits = 1024 */
};

/* Cutoffs read by _mpn_mod_poly_mullow for a truncated squaring:
   poly1 == poly2 with equal (clamped) lengths and
   len != len1 + len2 - 1. Each row is {classical -> Karatsuba,
   Karatsuba -> KS/fft_small}, compared against the operand length.
   Rows are indexed by (bits - FLINT_BITS - 1) / 16, where bits is the
   modulus bit count; the trailing comment on each row is its bit-size
   label. */
static const uint8_t sqrlow_cutoffs[][2] = {
  {122, 122},   /* bits = 80 */
  {110, 110},   /* bits = 96 */
  {114, 114},   /* bits = 112 */
  {173, 173},   /* bits = 128 */
  {98, 98},   /* bits = 144 */
  {81, 81},   /* bits = 160 */
  {102, 102},   /* bits = 176 */
  {97, 97},   /* bits = 192 */
  {124, 124},   /* bits = 208 */
  {122, 122},   /* bits = 224 */
  {120, 120},   /* bits = 240 */
  {135, 135},   /* bits = 256 */
  {105, 105},   /* bits = 272 */
  {111, 117},   /* bits = 288 */
  {125, 152},   /* bits = 304 */
  {146, 146},   /* bits = 320 */
  {104, 104},   /* bits = 336 */
  {111, 133},   /* bits = 352 */
  {109, 133},   /* bits = 368 */
  {127, 127},   /* bits = 384 */
  {95, 106},   /* bits = 400 */
  {93, 118},   /* bits = 416 */
  {93, 113},   /* bits = 432 */
  {111, 111},   /* bits = 448 */
  {96, 102},   /* bits = 464 */
  {80, 102},   /* bits = 480 */
  {80, 99},   /* bits = 496 */
  {96, 96},   /* bits = 512 */
  {90, 90},   /* bits = 528 */
  {80, 90},   /* bits = 544 */
  {79, 88},   /* bits = 560 */
  {85, 85},   /* bits = 576 */
  {77, 83},   /* bits = 592 */
  {64, 81},   /* bits = 608 */
  {63, 79},   /* bits = 624 */
  {77, 77},   /* bits = 640 */
  {63, 75},   /* bits = 656 */
  {48, 73},   /* bits = 672 */
  {48, 72},   /* bits = 688 */
  {70, 70},   /* bits = 704 */
  {60, 68},   /* bits = 720 */
  {48, 67},   /* bits = 736 */
  {60, 66},   /* bits = 752 */
  {64, 64},   /* bits = 768 */
  {46, 63},   /* bits = 784 */
  {46, 64},   /* bits = 800 */
  {47, 65},   /* bits = 816 */
  {59, 59},   /* bits = 832 */
  {46, 58},   /* bits = 848 */
  {44, 57},   /* bits = 864 */
  {46, 57},   /* bits = 880 */
  {55, 55},   /* bits = 896 */
  {46, 54},   /* bits = 912 */
  {45, 53},   /* bits = 928 */
  {45, 52},   /* bits = 944 */
  {48, 52},   /* bits = 960 */
  {32, 51},   /* bits = 976 */
  {44, 50},   /* bits = 992 */
  {44, 51},   /* bits = 1008 */
  {48, 48},   /* bits = 1024 */
};

/*
    Dispatcher for truncated polynomial multiplication: picks one of the
    _mpn_mod_poly_mullow_{classical,karatsuba,fft_small,KS} routines and
    forwards the call to it.

    len1 and len2 are first clamped to len. The choice is then driven by
    the shorter clamped length, compared against a pair of cutoffs taken
    from the tuning table that matches the shape of the call (squaring or
    general product, full-length or truncated, balanced or unbalanced)
    and the modulus bit count of ctx.

    Returns the status of the routine that ends up doing the work. If
    fft_small reports anything other than GR_SUCCESS, KS is used instead.
*/
int
_mpn_mod_poly_mullow(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong len, gr_ctx_t ctx)
{
    const uint8_t * cutoffs;   /* {classical -> Karatsuba, Karatsuba -> KS/fft_small} */
    slong shorter, longer, modulus_bits, row;
    int squaring;

    /* Clamp both operand lengths to the requested output length. */
    len1 = FLINT_MIN(len1, len);
    len2 = FLINT_MIN(len2, len);

    if (len1 <= len2)
    {
        shorter = len1;
        longer = len2;
    }
    else
    {
        shorter = len2;
        longer = len1;
    }

    /* Very short operand: go classical without consulting the tables. */
    if (shorter <= 3)
        return _mpn_mod_poly_mullow_classical(res, poly1, len1, poly2, len2, len, ctx);

    modulus_bits = MPN_MOD_CTX_MODULUS_BITS(ctx);

    FLINT_ASSERT(modulus_bits > FLINT_BITS);

    /* The tables hold one row per 16 bits of modulus size above FLINT_BITS. */
    row = (modulus_bits - FLINT_BITS - 1) / 16;

    squaring = (poly1 == poly2 && len1 == len2);

    if (!squaring && longer >= 2 * shorter)
        cutoffs = mul_unbalanced_cutoffs[row];
    else if (len == len1 + len2 - 1)
        cutoffs = squaring ? sqr_cutoffs[row] : mul_cutoffs[row];
    else
        cutoffs = squaring ? sqrlow_cutoffs[row] : mullow_cutoffs[row];

    if (shorter < cutoffs[0])
        return _mpn_mod_poly_mullow_classical(res, poly1, len1, poly2, len2, len, ctx);

    /* Karatsuba is only chosen while the longer operand stays below four
       times the shorter one. */
    if (shorter < cutoffs[1] && longer < 4 * shorter)
        return _mpn_mod_poly_mullow_karatsuba(res, poly1, len1, poly2, len2, len, -1, ctx);

    /* Try fft_small first; fall through to KS if it does not succeed. */
    if (_mpn_mod_poly_mullow_fft_small(res, poly1, len1, poly2, len2, len, ctx) == GR_SUCCESS)
        return GR_SUCCESS;

    return _mpn_mod_poly_mullow_KS(res, poly1, len1, poly2, len2, len, ctx);
}
