/******************************************************************************
* File: Strassen.hh
* Author: Keith Schwarz (htiek@cs.stanford.edu)
*
* An implementation of Strassen's algorithm for fast matrix multiplication.
* While a naive matrix multiply takes O(N^3) to multiply square matrices of
* size N x N, Strassen's algorithm runs in time O(N^(lg 7)) ~= O(N^2.8). It
* manages to do so by cleverly computing a family of smaller matrix products
* that can be used to recover all the multiplications necessary to compute the
* matrix product.
*
* To see how the algorithm works, let's begin by assuming that we're dealing
* with matrices that are of size 2^k x 2^k for some k >= 1. Then if we have
* the product C = AB, we can rewrite this as
*
* | C_00 C_01 |   | A_00 A_01 | | B_00 B_01 |
* | C_10 C_11 | = | A_10 A_11 | | B_10 B_11 |
*
* And therefore
*
* C_00 = A_00 B_00 + A_01 B_10
* C_01 = A_00 B_01 + A_01 B_11
* C_10 = A_10 B_00 + A_11 B_10
* C_11 = A_10 B_01 + A_11 B_11
*
* This involves making eight multiplications of matrices of size 2^(k-1) and
* doing a total of four additions. We can then get a runtime of
*
* T(n) <= 8T(n / 2) + O(n^2)
*
* Using the Master Theorem, this recurrence expands out to O(n^(lg 8)) =
* O(n^3).
*
* Strassen's key insight was twofold. First, if we do use a divide-and-
* conquer algorithm to compute matrix products, we can perform a constant
* number of matrix additions at each step without affecting the asymptotic
* runtime of the algorithm. Notice that in the above recurrence for T(n),
* as long as the time spent doing additions is O(n^2), we arrive at the
* same recurrence. The second observation is that the eight
* matrix products we need to compute to find the C_ij's can be computed
* indirectly by instead computing the products of seven other matrices that
* individually do not give the answer, but whose sums or differences do. In
* this way, it's possible to eliminate an unnecessary multiplication by
* instead doing many more matrix additions and subtractions.
*
* More formally, Strassen's algorithm works by computing the following seven
* matrix products:
*
* M_0 = (A_00 + A_01) B_11
* M_1 = A_00 (B_01 - B_11)
* M_2 = (A_10 + A_11) B_00
* M_3 = A_11 (B_10 - B_00)
* M_4 = (A_10 - A_00) (B_00 + B_01)
* M_5 = (A_00 + A_11) (B_00 + B_11)
* M_6 = (A_01 - A_11) (B_10 + B_11)
*
* Notice that
*
* M_5 + M_6 - M_0 + M_3 = +1 (A_00 + A_11) (B_00 + B_11)
*                         +1 (A_01 - A_11) (B_10 + B_11)
*                         -1 (A_00 + A_01) B_11
*                         +1 A_11 (B_10 - B_00)
*                       = +1 (A_00 B_00 + A_00 B_11 + A_11 B_00 + A_11 B_11)
*                         +1 (A_01 B_10 + A_01 B_11 - A_11 B_10 - A_11 B_11)
*                         -1 (A_00 B_11 + A_01 B_11)
*                         +1 (A_11 B_10 - A_11 B_00)
*                       =   (A_00 B_00 + A_00 B_11 + A_11 B_00 + A_11 B_11)
*                         + (A_01 B_10 + A_01 B_11 - A_11 B_10 - A_11 B_11)
*                         + (-A_00 B_11 - A_01 B_11)
*                         + (A_11 B_10 - A_11 B_00)
*                       = A_00 B_00 + A_01 B_10
*                       = C_00
*
* M_0 + M_1 = A_00 B_11 + A_01 B_11 + A_00 B_01 - A_00 B_11
* = A_00 B_01 + A_01 B_11
* = C_01
*
* M_2 + M_3 = A_10 B_00 + A_11 B_00 + A_11 B_10 - A_11 B_00
* = A_10 B_00 + A_11 B_10
* = C_10
*
* M_4 + M_5 + M_1 - M_2 = +1 (A_10 - A_00) (B_00 + B_01)
*                         +1 (A_00 + A_11) (B_00 + B_11)
*                         +1 A_00 (B_01 - B_11)
*                         -1 (A_10 + A_11) B_00
*                       = +1 (A_10 B_00 + A_10 B_01 - A_00 B_00 - A_00 B_01)
*                         +1 (A_00 B_00 + A_00 B_11 + A_11 B_00 + A_11 B_11)
*                         +1 (A_00 B_01 - A_00 B_11)
*                         -1 (A_10 B_00 + A_11 B_00)
*                       =   (A_10 B_00 + A_10 B_01 - A_00 B_00 - A_00 B_01)
*                         + (A_00 B_00 + A_00 B_11 + A_11 B_00 + A_11 B_11)
*                         + (A_00 B_01 - A_00 B_11)
*                         + (-A_10 B_00 - A_11 B_00)
*                       = A_10 B_01 + A_11 B_11
*                       = C_11
*
* In other words, these seven matrices are sufficient to recover all the C_ij
* and thus the entire matrix product.
*
* There are two cases we still need to address. First, this recursion needs
* a base case, since otherwise we'll just keep splitting indefinitely. This
* part is easy. One option would be to stop as soon as the submatrices hit
* size 1x1, in which case the multiplications are just straight scalar
* multiplications. A more efficient solution, and the one adopted in this
* implementation, is to pick some cutoff size (say, n = 4) and then to just
* do a straight matrix multiply whenever the matrices become this size or
* smaller.
*
* The other issue to consider is how to handle the case where the matrices
* aren't perfectly square matrices of size 2^k for some k. We can handle
* this case by padding each matrix until it does meet this requirement, then
* doing the multiplication, and finally extracting the resulting matrix from
* the padded product. If the input matrices are of size M x P and P x N,
* this approach takes time O(max{M, N, P}^{lg 7}).
*
* This code relies on the Matrix class provided by the Archive of Interesting
* Code. You can find it at
*
* http://www.keithschwarz.com/interesting/code/?dir=matrix
*/
#ifndef Strassen_Included
#define Strassen_Included
#include "Matrix.hh"
#include <algorithm> // For std::fill
/**
 * Matrix<M, N, T> StrassenProduct(const Matrix<M, P, T>& lhs,
 *                                 const Matrix<P, N, T>& rhs);
 * Usage: AB = StrassenProduct(A, B);
 * ---------------------------------------------------------------------------
 * Computes the product of two matrices using Strassen's algorithm.
 *
 * The operands need not be square or have power-of-two dimensions: both are
 * copied into zero-filled square matrices whose side is the smallest power of
 * two that is at least max(M, N, P). The padded matrices are multiplied, and
 * the top-left M x N corner of that product is returned.
 *
 * T must be constructible from the literal 0, which is used for the padding.
 */
template <size_t M, size_t N, size_t P, typename T>
const Matrix<M, N, T> StrassenProduct(const Matrix<M, P, T>& lhs,
                                      const Matrix<P, N, T>& rhs);
/* * * * * Implementation Below This Point * * * * */
namespace strassen_detail {
  /* Constant: NaiveCutoff
   * -------------------------------------------------------------------------
   * Square blocks whose side length is at most this value are multiplied with
   * the naive algorithm instead of being split further. This is the base-case
   * cutoff described in the comment at the top of the file.
   */
  static const size_t NaiveCutoff = 4;

  /* Class template: SquareStrassenImpl<N, T, UseNaive>
   * -------------------------------------------------------------------------
   * Multiplies two N x N matrices. The base case (UseNaive == true) and the
   * recursive case are separate partial specializations, so the choice is
   * made at compile time.
   *
   * Testing N with an ordinary runtime `if` would still force the compiler to
   * instantiate the recursive branch for every N, all the way down to
   * Matrix<0, 0, T>, even though that code can never execute.
   */
  template <size_t N, typename T, bool UseNaive = (N <= NaiveCutoff)>
  struct SquareStrassenImpl;

  /* Base case: small (or empty) matrices use the ordinary matrix product. */
  template <size_t N, typename T>
  struct SquareStrassenImpl<N, T, true> {
    static const Matrix<N, N, T> Multiply(const Matrix<N, N, T>& lhs,
                                          const Matrix<N, N, T>& rhs) {
      return lhs * rhs;
    }
  };

  /* Recursive case: split each operand into quadrants and combine Strassen's
   * seven half-size products. N is assumed to be even (in practice, a power
   * of two); an odd N would lose its last row and column in the split.
   */
  template <size_t N, typename T>
  struct SquareStrassenImpl<N, T, false> {
    static const Matrix<N, N, T> Multiply(const Matrix<N, N, T>& lhs,
                                          const Matrix<N, N, T>& rhs) {
      typedef Matrix<N/2, N/2, T> Block;
      typedef SquareStrassenImpl<N/2, T> Half;
      const size_t half = N / 2;

      /* A[i][j] and B[i][j] hold the quadrant in block row i, block column j.
       * .at(x, y) indexes within a quadrant; operator[] indexes the full-size
       * operands.
       */
      Block A[2][2], B[2][2];
      for (size_t i = 0; i < 2; ++i) {
        for (size_t j = 0; j < 2; ++j) {
          for (size_t x = 0; x < half; ++x) {
            for (size_t y = 0; y < half; ++y) {
              A[i][j].at(x, y) = lhs[i * half + x][j * half + y];
              B[i][j].at(x, y) = rhs[i * half + x][j * half + y];
            }
          }
        }
      }

      /* The seven products M_0 .. M_6 from the derivation at the top of the
       * file.
       */
      const Block
        M0 = Half::Multiply(A[0][0] + A[0][1], B[1][1]),
        M1 = Half::Multiply(A[0][0], B[0][1] - B[1][1]),
        M2 = Half::Multiply(A[1][0] + A[1][1], B[0][0]),
        M3 = Half::Multiply(A[1][1], B[1][0] - B[0][0]),
        M4 = Half::Multiply(A[1][0] - A[0][0], B[0][0] + B[0][1]),
        M5 = Half::Multiply(A[0][0] + A[1][1], B[0][0] + B[1][1]),
        M6 = Half::Multiply(A[0][1] - A[1][1], B[1][0] + B[1][1]);

      /* Quadrants of the result, recovered with sums and differences only. */
      const Block C[2][2] = { { M5 + M6 - M0 + M3, M0 + M1           },
                              { M2 + M3,           M4 + M5 + M1 - M2 } };

      /* Reassemble the quadrants: the inverse of the split above. */
      Matrix<N, N, T> result;
      for (size_t i = 0; i < 2; ++i) {
        for (size_t j = 0; j < 2; ++j) {
          for (size_t x = 0; x < half; ++x) {
            for (size_t y = 0; y < half; ++y) {
              result[i * half + x][j * half + y] = C[i][j].at(x, y);
            }
          }
        }
      }
      return result;
    }
  };

  /* Function: SquareStrassenProduct(const Matrix<N, N, T>& lhs,
   *                                 const Matrix<N, N, T>& rhs);
   * Usage: AB = SquareStrassenProduct(A, B);
   * -------------------------------------------------------------------------
   * Computes the product of two square matrices (whose sizes are assumed to
   * be perfect powers of two) using Strassen's algorithm. Once the side
   * length is at most NaiveCutoff, it falls back to the naive product.
   */
  template <size_t N, typename T>
  const Matrix<N, N, T> SquareStrassenProduct(const Matrix<N, N, T>& lhs,
                                              const Matrix<N, N, T>& rhs) {
    return SquareStrassenImpl<N, T>::Multiply(lhs, rhs);
  }

  /* Metafunction: Log2<N>
   * Usage: Log2<N>::value
   * -------------------------------------------------------------------------
   * A metafunction that computes floor(lg N) for N >= 1. Log2<0> has no base
   * case and must not be instantiated; MatrixSize special-cases 0 so that
   * it never asks for it.
   */
  template <size_t N> struct Log2;

  /* floor(lg 1) == 0 */
  template <> struct Log2<1> {
    static const size_t value = 0;
  };

  /* floor(lg N) = 1 + floor(lg(N / 2)) */
  template <size_t N> struct Log2 {
    static const size_t value = 1 + Log2<N/2>::value;
  };

  /* Metafunction: MatrixSize<N>
   * Usage: MatrixSize<N>::value
   * -------------------------------------------------------------------------
   * Given N, the side length of a matrix, returns the side length to use when
   * the matrix is fed into Strassen's algorithm. For a power of two this is
   * N itself; otherwise it is the smallest power of two greater than N.
   */
  template <size_t N> struct MatrixSize;

  /* An empty side stays empty. */
  template <> struct MatrixSize<0> {
    static const size_t value = 0;
  };

  /* Otherwise compute 2^floor(lg N). If that equals N, hand N back;
   * otherwise return twice that value. The shifts use a size_t operand so
   * they neither overflow an int for large N nor mix signed and unsigned
   * in the comparison.
   */
  template <size_t N> struct MatrixSize {
    static const size_t value =
      ((size_t(1) << Log2<N>::value) == N)
        ? N
        : (size_t(1) << (Log2<N>::value + 1));
  };

  /* Metafunction: Max<M, N>
   * Usage: Max<137, 42>::value
   * -------------------------------------------------------------------------
   * Returns the larger of the two inputs.
   */
  template <size_t M, size_t N> struct Max {
    static const size_t value = M >= N ? M : N;
  };
} /* namespace strassen_detail */
/* StrassenProduct (definition)
 * ---------------------------------------------------------------------------
 * Strassen's recursion only accepts square operands with power-of-two sides.
 * Both inputs are therefore copied into zero-filled Side x Side matrices,
 * where Side is the smallest power of two no smaller than any of M, N and P
 * (or 0 when all three are 0). The padding rows and columns are zero, so they
 * contribute nothing to the product. The top-left M x N corner of the padded
 * product is therefore exactly lhs * rhs.
 */
template <size_t M, size_t N, size_t P, typename T>
const Matrix<M, N, T> StrassenProduct(const Matrix<M, P, T>& lhs,
                                      const Matrix<P, N, T>& rhs) {
  static const size_t Largest =
    strassen_detail::Max<M, strassen_detail::Max<N, P>::value>::value;
  static const size_t Side = strassen_detail::MatrixSize<Largest>::value;
  typedef Matrix<Side, Side, T> Padded;

  /* Zero-filled canvases, each large enough to hold either operand. */
  Padded paddedLhs;
  Padded paddedRhs;
  std::fill(paddedLhs.begin(), paddedLhs.end(), T(0));
  std::fill(paddedRhs.begin(), paddedRhs.end(), T(0));

  /* Drop each operand into the top-left corner of its canvas. */
  for (size_t row = 0; row < M; ++row) {
    for (size_t col = 0; col < P; ++col) {
      paddedLhs[row][col] = lhs[row][col];
    }
  }
  for (size_t row = 0; row < P; ++row) {
    for (size_t col = 0; col < N; ++col) {
      paddedRhs[row][col] = rhs[row][col];
    }
  }

  const Padded paddedProduct =
    strassen_detail::SquareStrassenProduct(paddedLhs, paddedRhs);

  /* Crop the meaningful M x N region out of the padded product. */
  Matrix<M, N, T> cropped;
  for (size_t row = 0; row < M; ++row) {
    for (size_t col = 0; col < N; ++col) {
      cropped[row][col] = paddedProduct[row][col];
    }
  }
  return cropped;
}
#endif