/******************************************************************************
 * File: BinarySubset.hh
 * Author: Keith Schwarz (htiek@cs.stanford.edu)
 *
 * A collection of functions for generating and manipulating subsets in
 * lexicographic order based on the bijective correspondence between binary 
 * numbers and subsets.  For a more detailed explanation of the algorithms
 * implemented here, refer to my writeup "Generating Subsets Lexicographically
 * with Binary Numbers and Cyclic Shifts," available online at
 *
 *                     www.keithschwarz.com/binary-subsets
 */

#ifndef BinarySubset_Included
#define BinarySubset_Included

/**
 * Function: NthSubset(ForwardIterator begin, ForwardIterator end,
 *                     OutputIterator result, Integer k);
 * Usage: NthSubset(v.begin(), v.end(), back_inserter(subset), 137u);
 * ----------------------------------------------------------------------------
 * Given the sorted range [begin, end) and an integer n, copies the nth
 * lexicographically-smallest subset of [begin, end) to the range starting at
 * result.  It is assumed that n is in range and that the integer type used to
 * hold n is large enough to hold 2^(end - begin).
 */

template <typename ForwardIterator, typename OutputIterator, typename Integer>
OutputIterator NthSubset(ForwardIterator begin, ForwardIterator end,
                         OutputIterator result, Integer k);

/**
 * Function: SubsetIndex(InputIterator subsetBegin, InputIterator subsetEnd,
 *                       ForwardIterator setBegin, ForwardIterator setEnd);
 * Usage: cout << SubsetIndex<size_t>(s.begin(), s.end(), v.begin(), v.end());
 * ----------------------------------------------------------------------------
 * Given a subset of [setBegin, setEnd) defined in [subsetBegin, subsetEnd),
 * returns the index of that subset in the sequence of lexicographically-
 * ordered subsets of the master set, assuming the elements of the master set
 * are stored in sorted order.  The resulting integral type is assumed to be
 * large enough to hold 2^(setEnd - setBegin)
 */

template <typename Integer, typename InputIterator, typename ForwardIterator>
Integer SubsetIndex(InputIterator subsetBegin, InputIterator subsetEnd,
                    ForwardIterator setBegin, ForwardIterator setEnd);

/**
 * Function: SubsetIndex(InputIterator subsetBegin, InputIterator subsetEnd,
 *                       ForwardIterator setBegin, ForwardIterator setEnd,
 *                       Comparator comp);
 * Usage: cout << SubsetIndex<size_t>(s.begin(), s.end(), v.begin(), v.end());
 * ----------------------------------------------------------------------------
 * Given a subset of [setBegin, setEnd) defined in [subsetBegin, subsetEnd),
 * returns the index of that subset in the sequence of lexicographically-
 * ordered subsets of the master set, assuming the elements of the master set
 * are stored in sorted order.  The resulting integral type is assumed to be
 * large enough to hold 2^(setEnd - setBegin).  Comparisons are done according
 * to the comparator comp.
 */

template <typename Integer, typename InputIterator, typename ForwardIterator,
          typename Comparator>
Integer SubsetIndex(InputIterator subsetBegin, InputIterator subsetEnd,
                    ForwardIterator setBegin, ForwardIterator setEnd,
                    Comparator comp);

/* * * * * Implementation Below This Point * * * * */
#include <iterator>   // For std::distance, std::iterator_traits
#include <functional> // For std::less
#include <cassert>

template <typename ForwardIterator, typename OutputIterator, typename Integer>
OutputIterator NthSubset(ForwardIterator begin, ForwardIterator end,
                         OutputIterator result, Integer k) {
  /* Begin by seeing how many elements are in the range [begin, end). */
  const Integer n(std::distance(begin, end));

  /* Now, we need to invert the set of shifts to get back the raw index of the
   * subset in the canonical ordering of binary numbers.  This works by looking
   * at the bits of the number and doing a reverse cyclic shift of the
   * appropriate size whenever we find a zero.
   *
   * Normally, I wouldn't write a loop that counts down an unsigned value to
   * zero, but it's okay here because we stop on the last bit, not after the
   * last bit.
   */

  for (Integer bitIndex(n); bitIndex > Integer(0); -- bitIndex) {
    /* See if the bit an index k is set.  If not, we need to flip everything
     * below this point.
     */

    if (((k >> bitIndex) & Integer(1)) == Integer(0)) {
      /* Do a cyclic shift backwards of the bits after the current bit.  We do
       * this by checking whether the bits after this one are identically zero.
       * If so, we overwrite them all with 1s.  Otherwise, we just subtract
       * one.
       *
       * To see what the remaining bits are, we compute the bitwise AND of
       * a list of ones that spans the bits after this one.  We get a list like
       * this by taking 2^k and subtracting one.
       */

      const Integer ones = (Integer(1) << bitIndex) - Integer(1);

      if (k & ones)
        -- k;
      else
        k |= ones;
    }
  }

  /* Now, we have converted k into a binary number from which we can read off
   * the elements of the set one at a time by picking elements where the bits
   * are zero and skipping elements where the bits are one.  For this loop,
   * because we do need to count all the way down to zero, we count up from 0
   * to n and flip the iteration counter to change [0, n) to (n, 0].
   *
   * On each iteration, we also increment the begin iterator so that we keep
   * track of "the current element."
   */

  for (Integer bitIndex(0); bitIndex < n; ++ bitIndex, ++ begin) {
    /* Flip the index around, avoiding underflow. */
    Integer realIndex = (n - Integer(1)) - bitIndex;

    /* See if that bit is set.  If not, pick the element. */
    if ((k & (Integer(1) << realIndex)) == Integer(0))
      *result++ = *begin;
  }

  return result;
}

/* To determine the index of a subset, we first build up the integer
 * corresponding to the bitmap of the subset, and then apply the appropriate
 * shifts to convert that number to its final position.
 */

template <typename Integer, typename InputIterator, typename ForwardIterator,
          typename Comparator>
Integer SubsetIndex(InputIterator subsetBegin, InputIterator subsetEnd,
                    ForwardIterator setBegin, ForwardIterator setEnd,
                    Comparator comp) {
  /* See how many elements there are in the range. */
  const Integer n(std::distance(setBegin, setEnd));

  /* Scan across the two ranges determining which elements are present in the
   * subset and which aren't.  As we're doing so, we'll set the appropriate
   * bits in the bitmask.
   *
   * For simplicity, we'll pessimistically assume that the set is empty and
   * that all the bits in the bitmap are one.  We'll then clear all the bits
   * where a match occurs.
   *
   * Since we have to count down and Integer may be unsigned, we'll invert the
   * loop counter.
   */

  Integer bitmap = (Integer(1) << n) - Integer(1);
  for (Integer bitIndex(0); setBegin != setEnd && subsetBegin != subsetEnd;
       ++ bitIndex) {
    /* See how the current elements of the sequences compare.  If the current
     * element of the sequences match, record a zero bit in the appropriate
     * index.
     */

    if (!comp(*setBegin, *subsetBegin) && !comp(*subsetBegin, *setBegin)) {
      const Integer realIndex = n - Integer(1) - bitIndex;
      bitmap &= ~(Integer(1) << realIndex);

      /* Advance both iterators forward, since we just consumed the element. */
      ++ setBegin;
      ++ subsetBegin;
    } else {
      /* We didn't consume the current element of the subset, so advance the
       * master set pointer forward so we can try again.
       */

      ++ setBegin;
    }
  }

  /* We should have consumed all the elements from the subset.  If not, then
   * the elements aren't drawn from the master set.
   */

  assert (subsetBegin == subsetEnd);

  /* Now, scan the bits from the least-to-most significant direction.  Whenever
   * we encounter a zero bit, perform a shift of the appropriate size.  Since a
   * shift of size one has no effect, as a microoptimization we'll start at
   * position one.  This also simplifies the logic for doing the shift.  Note
   * that we go one step past the end, because we need to use the zero bit
   * before the bitmask to do one final shift at the end.
   */

  for (Integer bitIndex(1); bitIndex <= n; ++ bitIndex) {
    /* If this bit is a zero, we need to cycle the rest of the bits. */
    if ((bitmap & (Integer(1) << bitIndex)) == Integer(0)) {
      /* We're doing a forward cycle of the rest of the bits, which maps any
       * bit pattern other than all ones to the current number plus one.  The
       * bit pattern of all ones gets mapped to all zeros.  We'll thus make an
       * integer with the appropriate pattern of ones and use a logical AND to
       * see which case we're in.
       */

      Integer ones = (Integer(1) << bitIndex) - Integer(1);
      if ((bitmap & ones) == ones)
        bitmap &= ~ones;
      else
        ++ bitmap;
    }
  }

  return bitmap;
}

/* Non-comparator SubsetIndex implemented in terms of comparator-based
 * SubsetIndex.
 */

template <typename Integer, typename InputIterator, typename ForwardIterator>
Integer SubsetIndex(InputIterator subsetBegin, InputIterator subsetEnd,
                    ForwardIterator setBegin, ForwardIterator setEnd) {
  return 
    SubsetIndex<Integer>(subsetBegin, subsetEnd, setBegin, setEnd,
                         std::less<typename std::iterator_traits<InputIterator>::value_type>());
}

#endif