/****************************************************************************
 * File: SplayTree.hh
 * Author: Keith Schwarz (htiek@cs.stanford.edu)
 *
 * An implementation of a dictionary structure backed by a splay tree.  Splay
 * trees, first described by Sleator and Tarjan in their paper "Self-Adjusting
 * Binary Search Trees," are a type of binary search tree with excellent
 * amortized runtime guarantees.  Unlike other balanced search trees, such as
 * red/black trees, AVL trees, or AA trees, the splay tree maintains no
 * explicit balance information or auxiliary data.  Instead, it tries to
 * balance itself whenever an access occurs by moving the most-recently 
 * accessed node up to the root in a process called "splaying."  This splay
 * process works by performing particular tree rotations until the node in
 * question reaches the root.  It is this particular family of rotations, not
 * the fact that the root itself is being rotated, that accounts for the splay
 * tree's amortized O(lg n) runtime for each operation on the tree.  However,
 * this amortized guarantee is not the strongest selling point of the splay
 * tree; this data structure has other strong guarantees as well.  For example,
 * the number of comparisons performed by a splay tree given any sequence of
 * accesses or operations is within a constant factor of the theoretically
 * optimal amount given a fixed-shape binary search tree.
 *
 * The actual rotations involved in a splay tree fall into three categories -
 * "zig," "zig-zig," and "zig-zag."  These cases are described here:
 *
 * 1. "Zig:" If the node's parent is the root, the rotation is
 *
 *                            A        B
 *                           /    -->   \
 *                          B            A
 *
 *    The symmetric case is also considered here.
 *
 * 2. "Zig-Zig:" If the nodes are in the pattern
 *
 *                                 A
 *                                /
 *                               B
 *                              /
 *                             C
 *
 *    They are rotated into
 *
 *                         A         C
 *                        /           \
 *                       B     -->     B
 *                      /               \
 *                     C                 A
 *
 *
 *    This is done by rotating B with A, then C with B.
 *
 * 3. "Zig-Zag:"  If the nodes are in the shape
 *
 *                              A
 *                             /
 *                            B
 *                             \
 *                              C
 *
 *
 *     Then they are rotated to form
 *
 *                     A
 *                    /
 *                   B        ->         C
 *                    \                 / \
 *                     C               B   A
 *
 *     By rotating C with B and then C with A.
 *
 * This splay step is carried out whenever a node is accessed, inserted, or
 * deleted.  In some cases, namely deletion, multiply splays might be required.
 *
 * Another major advantage of the splay operation is that after a tree is
 * splayed at a node, that node ends up at the root.  This allows several
 * complex tree operations, such as joining two trees or splitting a tree in
 * two, to be carried out easily.  The implementation of delete uses this
 * property, for example.
 *
 * This implementation of the splay tree uses it to implement a sorted
 * associative array akin to the STL std::map type.  To support efficient
 * iteration, a the nodes in the splay tree have a doubly-linked list threaded
 * through them in ascending order.  This allows the iterators to scan over the
 * nodes without having to compute successors or resplay the tree.
 */

#ifndef SplayTree_Included
#define SplayTree_Included

#include <algorithm>   // For lexicographical_compare, equal, max
#include <functional>  // For less
#include <utility>     // For pair
#include <iterator>    // For iterator, reverse_iterator
#include <stdexcept>   // For out_of_range

/**
 * A map-like class backed by a splay tree.
 */

template <typename Key, typename Value, typename Comparator = std::less<Key> >
class SplayTree {
public:
  /**
   * Constructor: SplayTree(Comparator comp = Comparator());
   * Usage: SplayTree<string, int> mySplayTree;
   * Usage: SplayTree<string, int> mySplayTree(MyComparisonFunction);
   * -------------------------------------------------------------------------
   * Constructs a new, empty splay tree that uses the indicated comparator to
   * compare keys.
   */

  SplayTree(Comparator comp = Comparator());

  /**
   * Destructor: ~SplayTree();
   * Usage: (implicit)
   * -------------------------------------------------------------------------
   * Destroys the splay tree, deallocating all memory allocated internally.
   */

  ~SplayTree();

  /**
   * Copy functions: SplayTree(const SplayTree& other);
   *                 SplayTree& operator= (const SplayTree& other);
   * Usage: SplayTree<string, int> one = two;
   *        one = two;
   * -------------------------------------------------------------------------
   * Makes this splay tree equal to a deep-copy of some other splay tree.
   */

  SplayTree(const SplayTree& other);
  SplayTree& operator= (const SplayTree& other);

  /**
   * Type: iterator
   * Type: const_iterator
   * -------------------------------------------------------------------------
   * A pair of types that can traverse the elements of a splay tree in ascending
   * order.
   */

  class iterator;
  class const_iterator;

  /**
   * Type: reverse_iterator
   * Type: const_reverse_iterator
   * -------------------------------------------------------------------------
   * A pair of types that can traverse the elements of a splay tree in descending
   * order.
   */

  typedef std::reverse_iterator<iterator> reverse_iterator;
  typedef std::reverse_iterator<const_iterator> const_reverse_iterator;

  /**
   * std::pair<iterator, bool> insert(const Key& key, const Value& value);
   * Usage: mySplayTree.insert("Skiplist", 137);
   * -------------------------------------------------------------------------
   * Inserts the specified key/value pair into the splay tree.  If an entry with
   * the specified key already existed, this function returns false paired
   * with an iterator to the extant value.  If the entry was inserted
   * successfully, returns true paired with an iterator to the new element.
   */

  std::pair<iterator, bool> insert(const Key& key, const Value& value);

  /**
   * bool erase(const Key& key);
   * Usage: mySplayTree.erase("AVL Tree");
   * -------------------------------------------------------------------------
   * Removes the entry from the splay tree with the specified key, if it exists.
   * Returns whether or not an element was erased.  All outstanding iterators
   * remain valid, except for those referencing the deleted element.
   */

  bool erase(const Key& key);

  /**
   * iterator erase(iterator where);
   * Usage: mySplayTree.erase(mySplayTree.begin());
   * -------------------------------------------------------------------------
   * Removes the entry referenced by the specified iterator from the tree,
   * returning an iterator to the next element in the sequence.
   */

  iterator erase(iterator where);

  /**
   * iterator find(const Key& key);
   * const_iterator find(const Key& key);
   * Usage: if (mySplayTree.find("Skiplist") != mySplayTree.end()) { ... }
   * -------------------------------------------------------------------------
   * Returns an iterator to the entry in the splay tree with the specified key, or
   *  end() as as sentinel if it does not exist.
   */

  iterator find(const Key& key);
  const_iterator find(const Key& key) const;

  /**
   * Value& operator[] (const Key& key);
   * Usage: mySplayTree["skiplist"] = 137;
   * -------------------------------------------------------------------------
   * Returns a reference to the value associated with the specified key in the
   * splay tree.  If the key is not contained in the splay tree, it will be inserted
   * into the splay tree with a default-constructed Entry as its value.
   */

  Value& operator[] (const Key& key);

  /**
   * Value& at(const Key& key);
   * const Value& at(const Key& key) const;
   * Usage: mySplayTree.at("skiplist") = 137;
   * -------------------------------------------------------------------------
   * Returns a reference to the value associated with the specified key,
   * throwing a std::out_of_range exception if the key does not exist in the
   * splay tree.
   */

  Value& at(const Key& key);
  const Value& at(const Key& key) const;

  /**
   * (const_)iterator begin() (const);
   * (const_)iterator end() (const);
   * Usage: for (SplayTree<string, int>::iterator itr = t.begin(); 
   *             itr != t.end(); ++itr) { ... }
   * -------------------------------------------------------------------------
   * Returns iterators delineating the full contents of the splay tree.  Each
   * iterator acts as a pointer to a std::pair<const Key, Entry>.
   */

  iterator begin();
  iterator end();
  const_iterator begin() const;
  const_iterator end() const;

  /**
   * (const_)reverse_iterator rbegin() (const);
   * (const_)reverse_iterator rend() (const);
   * Usage: for (SplayTree<string, int>::reverse_iterator itr = s.rbegin(); 
   *             itr != s.rend(); ++itr) { ... }
   * -------------------------------------------------------------------------
   * Returns iterators delineating the full contents of the splay tree in reverse
   * order.
   */

  reverse_iterator rbegin();
  reverse_iterator rend();
  const_reverse_iterator rbegin() const;
  const_reverse_iterator rend() const;

  /**
   * (const_)iterator lower_bound(const Key& key) (const);
   * (const_)iterator upper_bound(const Key& key) (const);
   * Usage: for (SplayTree<string, int>::iterator itr = t.lower_bound("AVL");
   *             itr != t.upper_bound("skiplist"); ++itr) { ... }
   * -------------------------------------------------------------------------
   * lower_bound returns an iterator to the first element in the splay tree whose
   * key is at least as large as key.  upper_bound returns an iterator to the
   * first element in the splay tree whose key is strictly greater than key.
   */

  iterator lower_bound(const Key& key);
  iterator upper_bound(const Key& key);
  const_iterator lower_bound(const Key& key) const;
  const_iterator upper_bound(const Key& key) const;

  /**
   * std::pair<(const_)iterator, (const_)iterator> 
   *    equal_range(const Key& key) (const);
   * Usage: std::pair<SplayTree<int, int>::iterator, SplayTree<int, int>::iterator>
   *          range = t.equal_range("AVL");
   * -------------------------------------------------------------------------
   * Returns a range of iterators spanning the unique copy of the entry whose
   * key is key if it exists, and otherwise a pair of iterators both pointing
   * to the spot in the splay tree where the element would be if it were.
   */

  std::pair<iterator, iterator> equal_range(const Key& key);
  std::pair<const_iterator, const_iterator> equal_range(const Key& key) const;

  /**
   * size_t size() const;
   * Usage: cout << "SplayTree contains " << s.size() << " entries." << endl;
   * -------------------------------------------------------------------------
   * Returns the number of elements stored in the splay tree.
   */

  size_t size() const;

  /**
   * bool empty() const;
   * Usage: if (s.empty()) { ... }
   * -------------------------------------------------------------------------
   * Returns whether the splay tree contains no elements.
   */

  bool empty() const;

  /**
   * void swap(SplayTree& other);
   * Usage: one.swap(two);
   * -------------------------------------------------------------------------
   * Exchanges the contents of this splay tree and some other splay tree.  All
   * outstanding iterators are invalidated.
   */

  void swap(SplayTree& other);

private:
  /* A type representing a node in the splay tree. */
  struct Node {
    std::pair<const Key, Value> mValue; // The actual value stored here

    /* The children are stored in an array to make it easier to implement tree
     * rotations.  The first entry is the left child, the second the right.
     */

    Node* mChildren[2];

    /* Pointer to the parent node. */
    Node* mParent;

    /* Pointer to the next and previous node in the sorted sequence. */
    Node* mNext, *mPrev;

    /* Constructor sets up the value to the specified key/value pair. */
    Node(const Key& key, const Value& value);
  };

  /* A pointer to the first and last elements of the splay tree. */
  Node* mHead, *mTail;

  /* A pointer to the root of the tree.  This is marked mutable because the
   * splay operation needs to change this value, even though it doesn't change
   * the observable state of the tree.
   */

  mutable Node* mRoot;

  /* The comparator to use when storing elements. */
  Comparator mComp;

  /* The number of elements in the list. */
  size_t mSize;

  /* A utility base class for iterator and const_iterator which actually
   * supplies all of the logic necessary for the two to work together.  The
   * parameters are the derived type, the type of a pointer being visited, and
   * the type of a reference being visited.  This uses the Curiously-Recurring
   * Template Pattern to work correctly.
   */

  template <typename DerivedType, typename Pointer, typename Reference>
  class IteratorBase;
  template <typename DerivedType, typename Pointer, typename Reference>
  friend class IteratorBase;

  /* Make iterator and const_iterator friends as well so they can use the
   * Node type.
   */

  friend class iterator;
  friend class const_iterator;

  /* A utility function to perform a tree rotation to pull the child above its
   * parent.  This function is semantically const but not bitwise const, since
   * it changes the structure but not the content of the elements being
   * stored.
   */

  void rotateUp(Node* child) const;

  /* A utility function that performs a splay starting at the given node.  The
   * root of the tree and the tree's internal structure will be changed, but
   * the order of the nodes in the linked list will not.
   */

  void splay(Node* where) const;

  /* A utility function which does a BST search on the tree, looking for the
   * indicated node.  The return result is a pair of pointers, the first of
   * which is the node being searched for, or NULL if that node is not found.
   * The second node is that node's parent, which is either the parent of the
   * found node, or the last node visited in the tree before NULL was found
   * if the node was not found.  No splaying is performed.
   *
   * The pointers returned here are Node*s independently of whether the
   * receiver object is const.  It is up to the implementer to ensure that
   * this function does not subvert constness.
   */

  std::pair<Node*, Node*> findNode(const Key& key) const;

  /* A utility function which, given two splay trees 'left' and 'right' where
   * each value in 'left' is smaller than any value in 'right,' destructively
   * modifies the two trees by joining them together into a single splay tree
   * containing all of the nodes in each.  It then returns the root of this
   * new tree.
   */

  Node* mergeTrees(Node* left, Node* right) const;

  /* A utility function which, given a node and the node to use as its parent,
   * recursively deep-copies the tree rooted at that node, using the parent
   * node as the new tree's parent.
   */

  static Node* cloneTree(Node* toClone, Node* parent);

  /* A utility function which, given a tree and a pointer to the predecessor
   * of that tree, rewires the linked list in that tree to represent an
   * inorder traversal.  No fields are modified.  The return value is the node
   * with the highest key.
   */

  static Node* rethreadLinkedList(Node* root, Node* predecessor);
};

/* Comparison operators for SplayTrees. */
template <typename Key, typename Value, typename Comparator>
bool operator<  (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs);
template <typename Key, typename Value, typename Comparator>
bool operator<= (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs);
template <typename Key, typename Value, typename Comparator>
bool operator== (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs);
template <typename Key, typename Value, typename Comparator>
bool operator!= (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs);
template <typename Key, typename Value, typename Comparator>
bool operator>= (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs);
template <typename Key, typename Value, typename Comparator>
bool operator>  (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs);

/* * * * * Implementation Below This Point * * * * */

/* Definition of the IteratorBase type, which is used to provide a common
 * implementation for iterator and const_iterator.
 */

template <typename Key, typename Value, typename Comparator>
template <typename DerivedType, typename Pointer, typename Reference>
class SplayTree<Key, Value, Comparator>::IteratorBase {
public:
  /* Utility typedef to talk about nodes. */
  typedef typename SplayTree<Key, Value, Comparator>::Node Node;

  /* Advance operators just construct derived type instances of the proper
   * type, then advance them.
   */

  DerivedType& operator++ () {
    mCurr = mCurr->mNext;

    /* Downcast to our actual type. */
    return static_cast<DerivedType&>(*this);
  }
  const DerivedType operator++ (int) {
    /* Copy our current value by downcasting to our real type. */
    DerivedType result = static_cast<DerivedType&>(*this);

    /* Advance to the next element. */
    ++*this;

    /* Hand back the cached value. */
    return result;
  }

  /* Backup operators work on the same principle. */
  DerivedType& operator-- () {
    /* If the current pointer is NULL, it means that we've walked off the end
     * of the structure and need to back up a step.
     */

    if (mCurr == NULL) {
      mCurr = mOwner->mTail;
    }
    /* Otherwise, just back up a step. */
    else {
      mCurr = mCurr->mPrev;
    }

    /* Downcast to our actual type. */
    return static_cast<DerivedType&>(*this);
  }
  const DerivedType operator-- (int) {
    /* Copy our current value by downcasting to our real type. */
    DerivedType result = static_cast<DerivedType&>(*this);

    /* Back up a step. */
    --*this;

    /* Hand back the cached value. */
    return result;
  }

  /* Equality and disequality operators are parameterized - we'll allow anyone
   * whose type is IteratorBase to compare with us.  This means that we can
   * compare both iterator and const_iterator against one another.
   */

  template <typename DerivedType2, typename Pointer2, typename Reference2>
  bool operator== (const IteratorBase<DerivedType2, Pointer2, Reference2>& rhs) {
    /* Just check the underlying pointers, which (fortunately!) are of the 
     * same type.
     */

    return mOwner == rhs.mOwner && mCurr == rhs.mCurr;
  }
  template <typename DerivedType2, typename Pointer2, typename Reference2>
  bool operator!= (const IteratorBase<DerivedType2, Pointer2, Reference2>& rhs) {
    /* We are disequal if equality returns false. */
    return !(*this == rhs);
  }

  /* Pointer dereference operator hands back a reference. */
  Reference operator* () const {
    return mCurr->mValue;
  }
  
  /* Arrow operator returns a pointer. */
  Pointer operator-> () const {
    /* Use the standard "&**this" trick to dereference this object and return
     * a pointer to the referenced value.
     */

    return &**this;
  }

protected:
  /* Which SplayTree we belong to.  This pointer is const even though we are
   * possibly allowing ourselves to modify the splay tree elements to avoid having
   * to duplicate this logic once again for const vs. non-const iterators.
   */

  const SplayTree* mOwner;

  /* Where we are in the list. */
  Node* mCurr;

  /* In order for equality comparisons to work correctly, all IteratorBases
   * must be friends of one another.
   */

  template <typename Derived2, typename Pointer2, typename Reference2>
  friend class IteratorBase;

  /* Constructor sets up the splay tree and node pointers appropriately. */
  IteratorBase(const SplayTree* owner = NULL, Node* curr = NULL) 
  : mOwner(owner), mCurr(curr) {
    // Handled in initializer list
  }
};

/* iterator and const_iterator implementations work by deriving off of
 * IteratorBase, passing in parameters that make all the operators work.
 * Additionally, we inherit from std::iterator to import all the necessary
 * typedefs to qualify as an iterator.
 */

template <typename Key, typename Value, typename Comparator>
class SplayTree<Key, Value, Comparator>::iterator:
  public std::iterator< std::bidirectional_iterator_tag,
                        std::pair<const Key, Value> >,
  public IteratorBase<iterator,                       // Our type
                      std::pair<const Key, Value>*,   // Reference type
                      std::pair<const Key, Value>&> { // Pointer type 
public:
  /* Default constructor forwards NULL to base implicity. */
  iterator() {
    // Nothing to do here.
  }

  /* All major operations inherited from the base type. */

private:
  /* Constructor for creating an iterator out of a raw node just forwards this
   * argument to the base type.  This line is absolutely awful because the
   * type of the base is so complex.
   */

  iterator(const SplayTree* owner,
           typename SplayTree<Key, Value, Comparator>::Node* node) :
    IteratorBase<iterator,
                 std::pair<const Key, Value>*,
                 std::pair<const Key, Value>&>(owner, node) {
    // Handled by initializer list
  }

  /* Make the SplayTree a friend so it can call this constructor. */
  friend class SplayTree;

  /* Make const_iterator a friend so we can do iterator-to-const_iterator
   * conversions.
   */

  friend class const_iterator;
};

/* Same as above, but with const added in. */
template <typename Key, typename Value, typename Comparator>
class SplayTree<Key, Value, Comparator>::const_iterator:
  public std::iterator< std::bidirectional_iterator_tag,
                        const std::pair<const Key, Value> >,
  public IteratorBase<const_iterator,                       // Our type
                      const std::pair<const Key, Value>*,   // Reference type
                      const std::pair<const Key, Value>&> { // Pointer type 
public:
  /* Default constructor forwards NULL to base implicity. */
  const_iterator() {
    // Nothing to do here.
  }

  /* iterator conversion constructor forwards the other iterator's base fields
   * to the base class.
   */

  const_iterator(iterator itr) :
    IteratorBase<const_iterator,
                 const std::pair<const Key, Value>*,
                 const std::pair<const Key, Value>&>(itr.mOwner, itr.mCurr) {
    // Handled in initializer list
  }

  /* All major operations inherited from the base type. */

private:
  /* See iterator implementation for details about what this does. */
  const_iterator(const SplayTree* owner,
                 typename SplayTree<Key, Value, Comparator>::Node* node) :
    IteratorBase<const_iterator,
                 const std::pair<const Key, Value>*,
                 const std::pair<const Key, Value>&>(owner, node) {
    // Handled by initializer list
  }
  
  /* Make the SplayTree a friend so it can call this constructor. */
  friend class SplayTree;
};

/**** SplayTree::Node Implementation. ****/

/* Constructor sets up the value and priority, but leaves everything else
 * unset.  This is mostly to allow the fields to be const while still getting
 * the code to compile.
 */

template <typename Key, typename Value, typename Comparator>
SplayTree<Key, Value, Comparator>::Node::Node(const Key& key,
                                          const Value& value) 
  : mValue(key, value) {
  // Handled in initializer list.
}

/**** SplayTree Implementation ****/

/* Constructor sets up a new, empty SplayTree. */
template <typename Key, typename Value, typename Comparator>
SplayTree<Key, Value, Comparator>::SplayTree(Comparator comp) : mComp(comp) {
  /* Initially, the list of elements is empty and the tree is NULL. */
  mHead = mTail = mRoot = NULL;

  /* The tree is created empty. */
  mSize = 0;
}

/* Destructor walks the linked list of elements, deleting all nodes it
 * encounters.
 */

template <typename Key, typename Value, typename Comparator>
SplayTree<Key, Value, Comparator>::~SplayTree() {
  /* Start at the head of the list. */
  Node* curr = mHead;
  while (curr != NULL) {
    /* Cache the next value; we're about to blow up our only pointer to it. */
    Node* next = curr->mNext;

    /* Free memory, then go to the next node. */
    delete curr;
    curr = next;
  }
}

/* Inserting a node works by walking down the tree until the insert point is
 * found, adding the value, then splaying it up to the root.  If the key to be
 * inserted already exists, then it is splayed instead.
 */

template <typename Key, typename Value, typename Comparator>
std::pair<typename SplayTree<Key, Value, Comparator>::iterator, bool>
SplayTree<Key, Value, Comparator>::insert(const Key& key, const Value& value) {
  /* Recursively walk down the tree from the root, looking for where the value
   * should go.  In the course of doing so, we'll maintain some extra
   * information about the node's successor and predecessor so that we can
   * wire the new node in in O(1) time.
   *
   * The information that we'll need will be the last nodes at which we
   * visited the left and right child.  This is because if the new node ends
   * up as a left child, then its predecessor is the last ancestor on the path
   * where we followed its right pointer, and vice-versa if the node ends up
   * as a right child.
   */

  Node* lastLeft = NULL, *lastRight = NULL;
  
  /* Also keep track of our current location as a pointer to the pointer in
   * the tree where the node will end up, which allows us to insert the node
   * by simply rewiring this pointer.
   */

  Node** curr   = &mRoot;

  /* Also track the last visited node. */
  Node*  parent = NULL;

  /* Now, do a standard binary tree insert.  If we ever find the node, we can
   * stop early.
   */

  while (*curr != NULL) {
    /* Update the parent to be this node, since it's the last one visited. */
    parent = *curr;

    /* Check whether we belong in the left subtree. */
    if (mComp(key, (*curr)->mValue.first)) {
      lastLeft = *curr;
      curr = &(*curr)->mChildren[0];
    }
    /* ... or perhaps the right subtree. */
    else if (mComp((*curr)->mValue.first, key)) {
      lastRight = *curr; // Last visited node where we went right.
      curr = &(*curr)->mChildren[1];
    }
    /* Otherwise, the key must already exist in the tree.  Splay it to the
     * root, and then return a pointer to it.
     */

    else {
      /* Because we're about to do a splay operation to rewire the tree, we
       * need to cache what node we're going to return, not just the pointer
       * to it.  In particular, if we cache a pointer to the pointer to the
       * node, then after the splay that pointer might no longer be valid.
       */

      Node* toReturn = *curr;
      splay(toReturn);
      return std::make_pair(iterator(this, toReturn), false);
    }
  }

  /* At this point we've found our insertion point and can create the node
   * we're going to wire in.
   */

  Node* toInsert = new Node(key, value);
  
  /* Splice it into the tree. */
  toInsert->mParent = parent;
  *curr = toInsert;

  /* The new node has no children. */
  toInsert->mChildren[0] = toInsert->mChildren[1] = NULL;

  /* Wire this node into the linked list in-between its predecessor and
   * successor in the tree.  The successor is the last node where we went
   * left, and the predecessor is the last node where we went right.
   */

  toInsert->mNext = lastLeft;
  toInsert->mPrev = lastRight;

  /* Update the previous pointer of the next entry, or change the list tail
   * if there is no next entry.
   */

  if (toInsert->mNext)
    toInsert->mNext->mPrev = toInsert;
  else
    mTail = toInsert;

  /* Update the next pointer of the previous entry similarly. */
  if (toInsert->mPrev)
    toInsert->mPrev->mNext = toInsert;
  else
    mHead = toInsert;
  
  /* Splay this new node back up to the root. */
  splay(toInsert);

  /* Increase the size of the tree, since we just added a node. */
  ++mSize;

  /* Hand back an iterator to the new element, along with a notification that
   * it was inserted correctly.
   */

  return std::make_pair(iterator(this, toInsert), true);
}

/* To perform a tree rotation, we identify whether we're doing a left or
 * right rotation, then rewrite pointers as follows:
 *
 * In a right rotation, we do the following:
 *
 *      B            A
 *     / \          / \
 *    A   2   -->  0   B
 *   / \              / \
 *  0   1            1   2
 *
 * In a left rotation, this runs backwards.
 *
 * The reason that we've implemented the nodes as an array of pointers rather
 * than using two named pointers is that the logic is symmetric.  If the node
 * is its left child, then its parent becomes its right child, and the node's
 * right child becomes the parent's left child.  If the node is its parent's
 * right child, then the node's parent becomes its left child and the node's
 * left child becomes the parent's right child.  In other words, the general
 * formula is
 *
 * If the node is its parent's SIDE child, then the parent becomes that node's
 * OPPOSITE-SIDE child, and the node's OPPOSITE-SIDE child becomes the
 * parent's SIDE child.
 *
 * This code also updates the root if the tree root gets rotated out.
 */

template <typename Key, typename Value, typename Comparator>
void SplayTree<Key, Value, Comparator>::rotateUp(Node* node) const {
  /* Determine which side the node is on.  It's on the left (side 0) if the
   * parent's first pointer matches it, and is on the right (side 1) if the
   * node's first pointer doesn't match it.  This is, coincidentally, whether
   * the node is not equal to the first pointer of its root.
   */

  const int side = (node != node->mParent->mChildren[0]);
  
  /* The other side is the logical negation of the side itself. */
  const int otherSide = !side;

  /* Cache the displaced child and parent of the current node. */
  Node* child  = node->mChildren[otherSide];
  Node* parent = node->mParent;
  
  /* Shuffle pointers around to make the node the parent of its parent. */
  node->mParent = parent->mParent;
  node->mChildren[otherSide] = parent;

  /* Shuffle around pointers so that the parent takes on the displaced
   * child.
   */

  parent->mChildren[side] = child;
  if (child)
    child->mParent = parent;

  /* Update the grandparent (if any) so that its child is now the rotated
   * element rather than the parent.  If there is no grandparent, the node is
   * now the root.
   */

  if (parent->mParent) {
    const int parentSide = (parent != parent->mParent->mChildren[0]);
    parent->mParent->mChildren[parentSide] = node;
  } else
    mRoot = node;

  /* In either case, change the parent so that it now treats the node as the
   * parent.
   */

  parent->mParent = node;
}

/* Implementation of the splay operation for moving a node up to the root by a
 * series of intelligent rotations.  This logic is a bit tricky because we
 * have four cases to check: either we're the root, or the zig case, or the
 * zig-zig case, or the zig-zag case.
 */

template <typename Key, typename Value, typename Comparator>
void SplayTree<Key, Value, Comparator>::splay(Node* node) const {
  /* Continue moving the node upward until it becomes the root.  We'll also
   * check here if we were asked to splay the NULL tree, which is a no-op.
   */

  while (node && node->mParent) {
    /* For simplicity, keep track of the parent of this node. */
    Node* parent = node->mParent;

    /* Zig case: If the parent is the root, do just one rotation. */
    if (parent->mParent == NULL)
      rotateUp(node);

    /* Zig-zig case: If the node and its parent are either both left children
     * or both right children, do a zig-zig rotation by rotating the parent
     * with its parent, then the node with its own parent.
     *
     * To check whether the node is a left or right child, we use a similar
     * trick to the rotateUp method of comparing the node with its parent's
     * left child.
     */

    else if ((parent->mParent->mChildren[0] == parent) ==
        (node->mParent->mChildren[0] == node)) {
      rotateUp(parent);
      rotateUp(node);
    }

    /* Otherwise, we must be in the zig-zag case, in which case we rotate the
     * node with its parent twice.
     */

    else {
      rotateUp(node);
      rotateUp(node);
    }
  }
}

/* const version of find works by doing a standard BST search for the node in
 * question, then splaying the tree to that node.
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::const_iterator
SplayTree<Key, Value, Comparator>::find(const Key& key) const {
  /* Do a standard BST search to locate the node and its ancestor. */
  std::pair<Node*, Node*> result = findNode(key);

  /* Do the splay step.  If we found the node, splay it up to the root.  If
   * not, then splay to the root the last node we encountered.
   */

  splay(result.first? result.first : result.second);

  /* Wrap up whatever we found, even if it's NULL, and hand it back. */
  return const_iterator(this, result.first);
}

/* Non-const version of find implemented in terms of const find. */
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::iterator
SplayTree<Key, Value, Comparator>::find(const Key& key) {
  /* Get the underlying const_iterator by calling the const version of this
   * function.
   */

  const_iterator itr = static_cast<const SplayTree*>(this)->find(key);

  /* Strip off the constness by wrapping it up as a raw iterator. */
  return iterator(itr.mOwner, itr.mCurr);
}

/* findNode just does a standard BST lookup, recording the last node that was
 * found before the one that was ultimately returned.
 */

template <typename Key, typename Value, typename Comparator>
std::pair<typename SplayTree<Key, Value, Comparator>::Node*,
          typename SplayTree<Key, Value, Comparator>::Node*>
SplayTree<Key, Value, Comparator>::findNode(const Key& key) const {
  /* Start the search at the root and work downwards.  Keep track of the last
   * node we visited so that we can do a splay even if we walk off the tree.
   */

  Node* curr = mRoot, *prev = NULL;
  while (curr != NULL) {
    /* Update the prev pointer so that it tracks the last node we visited. */
    prev = curr;

    /* If the key is less than this node, go left. */
    if (mComp(key, curr->mValue.first))
      curr = curr->mChildren[0];
    /* Otherwise if the key is greater than the node, go right. */
    else if (mComp(curr->mValue.first, key))
      curr = curr->mChildren[1];
    /* Otherwise, we found the node.  Return that node and its parent as the
     * pair in question.  We explicitly use the parent here instead of prev
     * since the first part of this loop updates prev to be equal to curr.
     */

    else
      return std::make_pair(curr, curr->mParent);
  }

  /* If we ended up here, then we know that we didn't find the node in
   * question.  Handing back the pair of NULL and the most-recently-visited
   * node.  Note that due to the fact that NULL is #defined as zero, we have
   * to explicitly cast it to a Node* so that the template argument deduction
   * will work correctly; omitting this cast yields a pair<int, Node*>, which
   * gives a type error.
   */

  return std::make_pair((Node*)NULL, prev);
}

/* begin and end return iterators wrapping the head of the list or NULL,
 * respectively.
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::iterator
SplayTree<Key, Value, Comparator>::begin() {
  return iterator(this, mHead);
}
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::const_iterator
SplayTree<Key, Value, Comparator>::begin() const {
  return iterator(this, mHead);
}
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::iterator
SplayTree<Key, Value, Comparator>::end() {
  return iterator(this, NULL);
}
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::const_iterator
SplayTree<Key, Value, Comparator>::end() const {
  return iterator(this, NULL);
}

/* rbegin and rend return wrapped versions of end() and begin(),
 * respectively.
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::reverse_iterator
SplayTree<Key, Value, Comparator>::rbegin() {
  return reverse_iterator(end());
}
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::const_reverse_iterator
SplayTree<Key, Value, Comparator>::rbegin() const {
  return const_reverse_iterator(end());
}
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::reverse_iterator
SplayTree<Key, Value, Comparator>::rend() {
  return reverse_iterator(begin());
}
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::const_reverse_iterator
SplayTree<Key, Value, Comparator>::rend() const {
  return const_reverse_iterator(begin());
}

/* size just returns the cached size of the splay tree. */
template <typename Key, typename Value, typename Comparator>
size_t SplayTree<Key, Value, Comparator>::size() const {
  return mSize;
}

/* empty returns whether the size is zero. */
template <typename Key, typename Value, typename Comparator>
bool SplayTree<Key, Value, Comparator>::empty() const {
  return size() == 0;
}

/* Erasing an element works in three steps:
 *
 * 1. Splay the element erase up to the root.  Our tree is now the element to
 *    delete with left and right subtrees of the elements to retain.
 * 2. Use the merge operation to join these two subtrees together into a tree
 *    holding all the values from the original tree except the one to remove.
 * 3. Delete the root node.
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::iterator
SplayTree<Key, Value, Comparator>::erase(iterator where) {
  /* Extract the node pointer from the iterator. */
  Node* node = where.mCurr;

  /* Begin by splaying the element to remove up to the root. */
  splay(node);

  /* Detach the left and right subtrees from this node. */
  Node* lhs = node->mChildren[0];
  Node* rhs = node->mChildren[1];

  /* Make these nodes no longer treat the root element as their parents. */
  if (lhs) lhs->mParent = NULL;
  if (rhs) rhs->mParent = NULL;

  /* Merge these trees together into a tree holding the rest of the entries. */
  mRoot = mergeTrees(lhs, rhs);

  /* We've now removed the node in question from the tree structure, and now
   * we need to remove it from the doubly-linked list.
   */


  /* If there is a next node, wire its previous pointer around the current
   * node.  Otherwise, the tail just changed.
   */

  if (node->mNext)
    node->mNext->mPrev = node->mPrev;
  else
    mTail = node->mPrev;

  /* If there is a previous node, wite its next pointer around the current
   * node.  Otherwise, the head just changed.
   */

  if (node->mPrev)
    node->mPrev->mNext = node->mNext;
  else
    mHead = node->mNext;

  /* Since we need to return an iterator to the element in the tree after this
   * one, we'll cache the next pointer of the node to delete.  It won't be
   * available after we delete the node.
   */

  iterator result(this, node->mNext);

  /* Free the node's resources. */
  delete node;

  /* Decrease the logical size of this structure so we don't keep track of the
   * number of elements incorrectly.
   */

  --mSize;
  return result;
}

/* Merging two trees works by pulling the largest element of the left subtree
 * up to the root with a splay operation, then making its right child the
 * right subtree.
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::Node*
SplayTree<Key, Value, Comparator>::mergeTrees(Node* lhs, Node* rhs) const {
  /* Edge cases - if either the lhs or rhs are empty, just return the other
   * tree.  After all, the merge of any tree and the empty tree is just that
   * tree.
   */

  if (lhs == NULL) return rhs;
  if (rhs == NULL) return lhs;

  /* Find the largest node in the left-hand tree.  This works by marching down
   * the right spine while it's possible to do so.
   */

  Node* maxElem = lhs;
  while (maxElem->mChildren[1] != NULL) maxElem = maxElem->mChildren[1];

  /* Splay that node up to the root. */
  splay(maxElem);

  /* Make the right tree a child of this element. */
  maxElem->mChildren[1] = rhs;
  rhs->mParent = maxElem;

  /* Hand back this node as the root. */
  return maxElem;
}

/* Erasing a single value just calls find to locate the element and the
 * iterator version of erase to remove it.
 */

template <typename Key, typename Value, typename Comparator>
bool SplayTree<Key, Value, Comparator>::erase(const Key& key) {
  /* Look up where this node is, then remove it if it exists. */
  iterator where = find(key);
  if (where == end()) return false;

  erase(where);
  return true;
}

/* Square brackets implemented in terms of insert(). */
template <typename Key, typename Value, typename Comparator>
Value& SplayTree<Key, Value, Comparator>::operator[] (const Key& key) {
  /* Call insert to get a pair of an iterator and a bool.  Look at the
   * iterator, then consider its second field.
   */

  return insert(key, Value()).first->second;
}

/* at implemented in terms of find. */
template <typename Key, typename Value, typename Comparator>
const Value& SplayTree<Key, Value, Comparator>::at(const Key& key) const {
  /* Look up the key, failing if we can't find it. */
  const_iterator result = find(key);
  if (result == end())
    throw std::out_of_range("Key not found in splay tree.");

  /* Otherwise just return the value field. */
  return result->second;
}

/* non-const at implemented in terms of at using the const_cast/static_cast
 * trick.
 */

template <typename Key, typename Value, typename Comparator>
Value& SplayTree<Key, Value, Comparator>::at(const Key& key) {
  return const_cast<Value&>(static_cast<const SplayTree*>(this)->at(key));
}

/* The copy constructor is perhaps the most complex part of this entire
 * implementation.  It works in two passes.  First, the tree structure itself
 * is duplicated, without paying any attention to the next and previous
 * pointers threaded through.  Next, we run a recursive pass over the cloned
 * tree, fixing up all of the next and previous pointers as we go.
 */

template <typename Key, typename Value, typename Comparator>
SplayTree<Key, Value, Comparator>::SplayTree(const SplayTree& other) {
  /* Start off with the simple bits - copy over the size field and 
   * comparator. 
   */

  mSize = other.mSize;
  mComp = other.mComp;

  /* Clone the tree structure. */
  mRoot = cloneTree(other.mRoot, NULL);

  /* Rectify the linked list. */
  rethreadLinkedList(mRoot, NULL);

  /* Finally, fix up the first and last pointers of the list by looking for
   * the smallest and largest elements in the tree.
   */

  mTail = mHead = mRoot;
  while (mHead && mHead->mChildren[0]) mHead = mHead->mChildren[0];
  while (mTail && mTail->mChildren[1]) mTail = mTail->mChildren[1];
}

/* Cloning a tree is a simple structural recursion. */
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::Node*
SplayTree<Key, Value, Comparator>::cloneTree(Node* toClone, Node* parent) {
  /* Base case: the clone of the empty tree is that tree itself. */
  if (toClone == NULL) return NULL;

  /* Create a copy of the node, moving over the priorities and key/value
   * pair.
   */

  Node* result = new Node(toClone->mValue.first, toClone->mValue.second);

  /* Recursively clone the subtrees. */
  for (int i = 0; i < 2; ++i)
    result->mChildren[i] = cloneTree(toClone->mChildren[i], result);

  /* Set the parent. */
  result->mParent = parent;

  return result;
}

/* Fixing up the doubly-linked list is a bit tricky.  The function acts as an
 * inorder traversal.  We first fix up the left subtree, getting a pointer to
 * the node holding the largest value in that subtree (the predecessor of this
 * node).  We then chain the current node into the linked list, then fix up
 * the nodes to the right (which have the current node as their predecessor).
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::Node*
SplayTree<Key, Value, Comparator>::rethreadLinkedList(Node* root, Node* predecessor) {
  /* Base case: if the root is null, then the largest element visited so far
   * is whatever we were told it was.
   */

  if (root == NULL) return predecessor;

  /* Otherwise, recursively fix up the left subtree using the actual
   * predecessor.  Store the return value as the new predecessor.
   */

  predecessor = rethreadLinkedList(root->mChildren[0], predecessor);

  /* Add ourselves to the linked list. */
  root->mPrev = predecessor;
  if (predecessor)
    predecessor->mNext = root;
  root->mNext = NULL;

  /* Recursively invoke on the right subtree, passing in this node as the
   * predecessor.
   */

  return rethreadLinkedList(root->mChildren[1], root);
}

/* Assignment operator implemented using copy-and-swap. */
template <typename Key, typename Value, typename Comparator>
SplayTree<Key, Value, Comparator>&
SplayTree<Key, Value, Comparator>::operator= (const SplayTree& other) {
  SplayTree clone = other;
  swap(clone);
  return *this;
}

/* swap just does an element-by-element swap. */
template <typename Key, typename Value, typename Comparator>
void SplayTree<Key, Value, Comparator>::swap(SplayTree& other) {
  /* Use std::swap to get the job done. */
  std::swap(mRoot, other.mRoot);
  std::swap(mSize, other.mSize);
  std::swap(mHead, other.mHead);
  std::swap(mTail, other.mTail);
  std::swap(mComp, other.mComp);
}

/* lower_bound works by walking down the tree to where the node belongs.  If
 * it's in the tree, then it's its own lower bound.  Otherwise, we either
 * found the predecessor or successor of the node in question, and correct it
 * to the resulting node.
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::const_iterator
SplayTree<Key, Value, Comparator>::lower_bound(const Key& key) const {
  /* One unusual edge case that complicates the logic here is what to do if
   * the tree is empty.  If this happens, then the lower_bound is end().
   */

  if (empty()) return end();

  /* Do a find operation, then resplay as in find(). */
  std::pair<Node*, Node*> result = findNode(key);
  splay(result.first? result.first : result.second);

  /* If we found the node we wanted, we can just wrap it up as an iterator. */
  if (result.first)
    return iterator(this, result.first);

  /* Otherwise, the value isn't here, but we do know the value in the tree
   * that would be its parent.  This value is therefore either the predecessor
   * or the successor of the value in question.  If it's the predecessor, then
   * we need to advance it forward one step to get the smallest value greater
   * than the indicated key.  Note that we can assume that there is some
   * predecessor, since we know that the tree is not empty.
   *
   * To check whether we're looking at the predecessor, we're curious whether
   * the key field of the value of the node of the second Node*.  Phew!
   */

  if (mComp(result.second->mValue.first, key))
    result.second = result.second->mNext;
  
  return iterator(this, result.second);
}

/* Non-const version of this function implemented by calling the const version
 * and stripping constness.
 */

template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::iterator
SplayTree<Key, Value, Comparator>::lower_bound(const Key& key) {
  /* Call the const version to get the answer. */
  const_iterator result = static_cast<const SplayTree*>(this)->lower_bound(key);

  /* Rewrap it in a regular iterator to remove constness. */
  return iterator(result.mOwner, result.mCurr);
}

/* equal_range looks up where the node should be.  If it finds it, it hands
 * back iterators spanning it.  If not, it just hands back two iterators to the
 * same spot.
 */

template <typename Key, typename Value, typename Comparator>
std::pair<typename SplayTree<Key, Value, Comparator>::const_iterator,
          typename SplayTree<Key, Value, Comparator>::const_iterator>
SplayTree<Key, Value, Comparator>::equal_range(const Key& key) const {
  /* Call lower_bound to find out where we should start looking. */
  std::pair<const_iterator, const_iterator> result;
  result.first = result.second = lower_bound(key);

  /* If we hit the end, we're done. */
  if (result.first == end()) return result;

  /* Otherwise, check whether the iterator we found matches the value.  If so,
   * bump the second iterator one step.
   */

  if (!mComp(key, result.second->first))
    ++result.second;

  return result;
}

/* Non-const version calls the const version, then strips off constness. */
template <typename Key, typename Value, typename Comparator>
std::pair<typename SplayTree<Key, Value, Comparator>::iterator,
          typename SplayTree<Key, Value, Comparator>::iterator>
SplayTree<Key, Value, Comparator>::equal_range(const Key& key) {
  /* Invoke const version to get the iterators. */
  std::pair<const_iterator, const_iterator> result =
    static_cast<const SplayTree*>(this)->equal_range(key);

  /* Unwrap into regular iterators. */
  return std::make_pair(iterator(result.first.mOwner,  result.first.mCurr),
                        iterator(result.second.mOwner, result.second.mCurr));
}

/* upper_bound just calls equal_range and returns the second value. */
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::iterator
SplayTree<Key, Value, Comparator>::upper_bound(const Key& key) {
  return equal_range(key).second;
}
template <typename Key, typename Value, typename Comparator>
typename SplayTree<Key, Value, Comparator>::const_iterator
SplayTree<Key, Value, Comparator>::upper_bound(const Key& key) const {
  return equal_range(key).second;
}

/* Comparison operators == and < use the standard STL algorithms. */
template <typename Key, typename Value, typename Comparator>
bool operator<  (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs) {
  return std::lexicographical_compare(lhs.begin(), lhs.end(),
                                      rhs.begin(), rhs.end());
}
template <typename Key, typename Value, typename Comparator>
bool operator== (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs) {
  return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), 
                                                rhs.begin());
}

/* Remaining comparisons implemented in terms of the above comparisons. */
template <typename Key, typename Value, typename Comparator>
bool operator<= (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs) {
  /* x <= y   iff !(x > y)   iff !(y < x) */
  return !(rhs < lhs);
}
template <typename Key, typename Value, typename Comparator>
bool operator!= (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs) {
  return !(lhs == rhs);
}
template <typename Key, typename Value, typename Comparator>
bool operator>= (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs) {
  /* x >= y   iff !(x < y) */
  return !(lhs < rhs);
}
template <typename Key, typename Value, typename Comparator>

bool operator>  (const SplayTree<Key, Value, Comparator>& lhs,
                 const SplayTree<Key, Value, Comparator>& rhs) {
  /* x > y iff y < x */
  return rhs < lhs;
}

#endif