/***************************************************************
 * File: Recurrence.hh
 * Author: Keith Schwarz (htiek@cs.stanford.edu)
 *
 * Implementation of an algorithm for efficiently producing the
 * values generated by a linear recurrence.  A linear recurrence
 * is a series of values defined recursively.  For some k, let
 *
 * x_0 = a_0, x_1 = a_1, ..., x_k = a_k
 *
 * Then, let
 *
 * x_{n+k+1} = c_0 x_{n} + c_1 x_{n+1} + ... + c_k x_{k}
 *
 * The most famous example of a linear recurrence relation is the
 * Fibonacci series, where
 *
 * x_0 = 0, x_1 = 1, x_{n+2} = x_n + x_{n+1}
 *
 * The first few terms of this series are
 *
 * 0, 1, 1, 2, 3, 5, 8, 13, 21, ...
 *
 * The "Tribonacci" series is similarly defined as
 *
 * x_0 = 0, x_1 = 0, x_1 = 1, x_{n+3} = x_n + x_{n+1} + x_{n+2}
 *
 * Whose terms are
 *
 * 0, 0, 1, 1, 2, 4, 7, 13, 24, ...
 *
 * Both the Fibonacci and "Tribonacci" sequences have unity as their
 * coefficients, but it's possible to build series where this is not
 * the case.  For example, consider this series:
 *
 * x_0 = 0, x_1 = 1, x_{n+2} = 2x_n + x_{n+1}
 *
 * Its terms are
 *
 * 0, 1, 1, 3, 5, 11, 21, 43, ...
 *
 * Given a linear recurrence, we say that the "degree" of this recurrence
 * is the number of terms upon which the next term depends.  For Fibonacci
 * numbers, this is two, while for "Tribonacci" numbers it's three.  This
 * implementation can solve linear recurrences in time O(k^3 lg n).  In
 * particular, note that for a fixed k this algorithm is O(lg n), so it
 * can be used to compute arbitrary terms of any fixed recurrence relation
 * in logarithmic time.
 *
 * The key trick behind this algorithm (due to a 1966 paper by Miller and
 * Brown) is to recast this problem in terms of matrix multiplications.
 * Given a recurrence relation
 *
 * x_0 = a_0, x_1 = a_1, ..., x_k = a_k
 * x_{n+k+1} = c_0 x_{n} + c_1 x_{n} + ... + c_k x_{n+k}
 *
 * Consider the matrix
 *
 * | 0   1   0       0 |
 * | 0   0   1  ...  0 |
 * | 0   0   0       0 |
 * |     .       .   . | = C
 * |     .        .  . |
 * |     .         . . |
 * | 0   0   0       1 |
 * |c_0 c_1 c_2 ... c_k|
 *
 * Then note that
 *
 * | 0   1   0       0 | |x_n    |   |              x_{n+1}                    |
 * | 0   0   1  ...  0 | |x_{n+1}|   |              x_{n+2}                    |
 * | 0   0   0       0 | |       |   |                 .                       |
 * |     .       .   . | | .     | = |                 .                       |
 * |     .        .  . | | .     |   |                 .                       |
 * |     .         . . | | .     |   |              x_{n+k-1}                  |
 * | 0   0   0       1 | |       |   |              x_{n+k}                    |
 * |c_0 c_1 c_2 ... c_k| |x_{n+k}|   |c_0 x_n + c_1 x_{n+1} + ... + c_k x_{n+k}|
 *
 *                                   | x_{n+1} |
 *                                   | x_{n+2} |
 *                                   |    .    |
 *                                 = |    .    |
 *                                   |    .    |
 *                                   | x_{n+k} |
 *                                   |x_{n+k+1}|
 *
 * Moreover, this process can be iterated.  Let X = (x_n, x_{n+1}, ..., x_{n+k}).
 * Then (C^m)X = (x_{n+m}, x_{n+m+1}, ..., x_{n+m+k}).  In particular, for any
 * n > k, we have that x_n is the kth component of (C^{n-k})(a_0, a_1, ..., a_k).
 * This gives us a fast algorithm for computing linear recurrences.  If we can raise
 * C to the (n-k)th power, we can then multiply the vector (a_0, a_1, ..., a_k) by
 * that matrix, then take the last component.
 *
 * Fortunately, we can compute C^n efficiently (in O(lg n) time) for any n using
 * the fact that
 *
 * C^0        = I
 * C^1        = C
 * C^{2n}     = C^n * C^n
 * C^{2n + 1} = C^n * C^n * C
 *
 * Each of these matrix multiplications takes O(k^3) time using a naive algorithm,
 * giving us the final runtime of O(k^3 lg n).
 *
 * This code relies on the Matrix.hh header file also contained in the Archive of
 * Interesting Code.  You can find it at
 *
 *           http://www.keithschwarz.com/interesting/code/?dir=matrix
 */


#ifndef Recurrence_Included
#define Recurrence_Included

#include "Matrix.hh" // For Matrix, Vector
#include <algorithm> // For copy, fill

/**
 * Function: RecurrenceValue(const Vector<T, K>& initialValues,
 *                           const Vector<T, K>& coefficients,
 *                           size_t n);
 * Usage: cout << RecurrenceValue(fibInitials, fibCoeffs, 30) << endl;
 * ----------------------------------------------------------------------------
 * Given a recurrence relation whose first terms are iV[0], iV[1], ..., iV[k-1]
 * recursively defined as
 *
 * X_{n+k} = c[0] X_{n} + c[1] X_{n+1} + ... + c[k-1] X_{n+k-1}
 *
 * Returns the nth term of the series.
 */

template <size_t K, typename T>
T RecurrenceValue(const Vector<K, T>& initialValues,
                  const Vector<K, T>& coefficients,
                  size_t n);

/* * * * * Implementation Below This Point * * * * */
namespace recurrence_detail {
  /* Utility function to raise a matrix to a given power. */
  template <size_t K, typename T>
  Matrix<K, K, T> MatrixPower(const Matrix<K, K, T>& matrix, size_t n) {
    /* Base cases: the zeroth or first power of a matrix are immediate. */
    if (n == 0return Identity<K, T>();
    if (n == 1return matrix;
    
    /* Otherwise, compute MatrixPower(floor(n / 2)).  Both the odd and
     * even cases use this fact.
     */

    Matrix<K, K, T> sqrtPower = MatrixPower(matrix, n / 2);
    
    /* If n is even, then the result is just the above matrix squared:
     *
     * (M ^ (n/2)) ^ 2 = M ^ (2n)
     */

    if (n % 2 == 0)
      return sqrtPower * sqrtPower;
    
    /* Otherwise, n is odd, and the result is the square of this matrix,
     * multiplied by the matrix itself:
     *
     * ((M ^ (n/2)) ^ 2) * M = M ^ (2n + 1)
     */

    return sqrtPower * sqrtPower * matrix;
  }
  
  /* Utility function to compute the matrix C (as described by the file 
   * comments), which computes another iteration of the recurrence
   * relation.
   */

  template <size_t K, typename T>
  Matrix<K, K, T> RecurrenceTransformMatrix(const Vector<K, T>& coefficients) {
    /* As a reminder, the matrix we want is defined here:
     *
     * | 0   1   0       0 |
     * | 0   0   1  ...  0 |
     * | 0   0   0  .    0 |
     * |     .       .   . | = C
     * |     .        .  . |
     * |     .           . |
     * | 0   0   0       1 |
     * |c_0 c_1 c_2 ... c_k|
     */

    Matrix<K, K, T> C;
    std::fill(C.begin(), C.end(), T(0));

    /* Copy the coefficients into the bottom row. */
    std::copy(coefficients.begin(), coefficients.end(),
        C.row_begin(K - 1));
  
    /* Write the diagonal elements in. */
    for (size_t i = 0; i < K - 1; ++i)
      C[i][i + 1] = T(1);

    return C;
  }
}

/* Implementation of RecurrenceValue. */
template <size_t K, typename T>
T RecurrenceValue(const Vector<K, T>& initialValues,
                  const Vector<K, T>& coefficients,
                  size_t n) {
  using namespace recurrence_detail;
  
  /* If the index we want is in the range [0, k - 1], then we already have the
   * answer.
   */

  if (n < K)
    return initialValues[n];

  /* Otherwise, we'll compute the transform matrix and raise it to the power
   * of (n - k) + 1.  This matrix, when multiplied by the initial values, produces 
   * a vector whose last element is the nth term of the recurrence.
   */

  Matrix<K, K, T> C = MatrixPower(RecurrenceTransformMatrix(coefficients), n - K + 1);

  /* We could just directly multiply the initial values by C to get the final value,
   * but that would be wasteful.  We just need the value in the final position.  To
   * do this, we multiply the last column of the matrix by the initial values, which
   * produces the final element of the vector.
   */

  T result = T(0);
  for (size_t i = 0; i < K; ++i)
    result += C[i][K - 1] * initialValues[i];

  return result;
}

#endif