[libcxx] Fix PR31402:  map::__find_equal_key has undefined behavior.

Summary:
This patch fixes llvm.org/PR31402 by replacing `map::__find_equal_key` with `__tree::__find_equal`, which has already addressed the same undefined behavior.

Unfortunately I haven't been able to write a test case which causes the UBSAN diagnostic mentioned in the bug report. I can write tests which exercise the UB but for some reason they do not cause UBSAN to fail. Any help writing a test case would be appreciated.


Reviewers: mclow.lists, vsk, EricWF

Subscribers: cfe-commits

Differential Revision: https://reviews.llvm.org/D28131

git-svn-id: https://llvm.org/svn/llvm-project/libcxx/trunk@291087 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/__tree b/include/__tree
index a8842f3..dd32f70 100644
--- a/include/__tree
+++ b/include/__tree
@@ -1397,10 +1397,17 @@
     __node_base_pointer&
         __find_leaf(const_iterator __hint,
                     __parent_pointer& __parent, const key_type& __v);
+    // FIXME: Make this function const qualified. Unfortunetly doing so
+    // breaks existing code which uses non-const callable comparators.
     template <class _Key>
     __node_base_pointer&
         __find_equal(__parent_pointer& __parent, const _Key& __v);
     template <class _Key>
+    _LIBCPP_INLINE_VISIBILITY __node_base_pointer&
+    __find_equal(__parent_pointer& __parent, const _Key& __v) const {
+      return const_cast<__tree*>(this)->__find_equal(__parent, __v);
+    }
+    template <class _Key>
     __node_base_pointer&
         __find_equal(const_iterator __hint, __parent_pointer& __parent,
                      __node_base_pointer& __dummy,
diff --git a/include/map b/include/map
index c99359e..9555aad 100644
--- a/include/map
+++ b/include/map
@@ -533,7 +533,7 @@
         using _VSTD::swap;
         swap(comp, __y.comp);
     }
-    
+
 #if _LIBCPP_STD_VER > 11
     template <typename _K2>
     _LIBCPP_INLINE_VISIBILITY
@@ -730,7 +730,7 @@
     friend _LIBCPP_INLINE_VISIBILITY
     bool operator==(const __map_iterator& __x, const __map_iterator& __y)
         {return __x.__i_ == __y.__i_;}
-    friend 
+    friend
     _LIBCPP_INLINE_VISIBILITY
     bool operator!=(const __map_iterator& __x, const __map_iterator& __y)
         {return __x.__i_ != __y.__i_;}
@@ -895,7 +895,7 @@
 
 #if _LIBCPP_STD_VER > 11
     template <class _InputIterator>
-    _LIBCPP_INLINE_VISIBILITY 
+    _LIBCPP_INLINE_VISIBILITY
     map(_InputIterator __f, _InputIterator __l, const allocator_type& __a)
         : map(__f, __l, key_compare(), __a) {}
 #endif
@@ -961,7 +961,7 @@
         }
 
 #if _LIBCPP_STD_VER > 11
-    _LIBCPP_INLINE_VISIBILITY 
+    _LIBCPP_INLINE_VISIBILITY
     map(initializer_list<value_type> __il, const allocator_type& __a)
         : map(__il, key_compare(), __a) {}
 #endif
@@ -1297,6 +1297,7 @@
     typedef typename __base::__node_allocator          __node_allocator;
     typedef typename __base::__node_pointer            __node_pointer;
     typedef typename __base::__node_base_pointer       __node_base_pointer;
+    typedef typename __base::__parent_pointer          __parent_pointer;
 
     typedef __map_node_destructor<__node_allocator> _Dp;
     typedef unique_ptr<__node, _Dp> __node_holder;
@@ -1304,65 +1305,9 @@
 #ifdef _LIBCPP_CXX03_LANG
     __node_holder __construct_node_with_key(const key_type& __k);
 #endif
-
-    __node_base_pointer const&
-    __find_equal_key(__node_base_pointer& __parent, const key_type& __k) const;
-
-    _LIBCPP_INLINE_VISIBILITY
-    __node_base_pointer&
-    __find_equal_key(__node_base_pointer& __parent, const key_type& __k) {
-        map const* __const_this = this;
-        return const_cast<__node_base_pointer&>(
-            __const_this->__find_equal_key(__parent, __k));
-    }
 };
 
 
-// Find __k
-// Set __parent to parent of null leaf and
-//    return reference to null leaf iv __k does not exist.
-// If __k exists, set parent to node of __k and return reference to node of __k
-template <class _Key, class _Tp, class _Compare, class _Allocator>
-typename map<_Key, _Tp, _Compare, _Allocator>::__node_base_pointer const&
-map<_Key, _Tp, _Compare, _Allocator>::__find_equal_key(__node_base_pointer& __parent,
-                                                       const key_type& __k) const
-{
-    __node_pointer __nd = __tree_.__root();
-    if (__nd != nullptr)
-    {
-        while (true)
-        {
-            if (__tree_.value_comp().key_comp()(__k, __nd->__value_.__cc.first))
-            {
-                if (__nd->__left_ != nullptr)
-                    __nd = static_cast<__node_pointer>(__nd->__left_);
-                else
-                {
-                    __parent = static_cast<__node_base_pointer>(__nd);
-                    return __parent->__left_;
-                }
-            }
-            else if (__tree_.value_comp().key_comp()(__nd->__value_.__cc.first, __k))
-            {
-                if (__nd->__right_ != nullptr)
-                    __nd = static_cast<__node_pointer>(__nd->__right_);
-                else
-                {
-                    __parent = static_cast<__node_base_pointer>(__nd);
-                    return __parent->__right_;
-                }
-            }
-            else
-            {
-                __parent = static_cast<__node_base_pointer>(__nd);
-                return __parent;
-            }
-        }
-    }
-    __parent = static_cast<__node_base_pointer>(__tree_.__end_node());
-    return __parent->__left_;
-}
-
 #ifndef _LIBCPP_CXX03_LANG
 
 template <class _Key, class _Tp, class _Compare, class _Allocator>
@@ -1400,8 +1345,8 @@
 _Tp&
 map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k)
 {
-    __node_base_pointer __parent;
-    __node_base_pointer& __child = __find_equal_key(__parent, __k);
+    __parent_pointer __parent;
+    __node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
     __node_pointer __r = static_cast<__node_pointer>(__child);
     if (__child == nullptr)
     {
@@ -1440,8 +1385,8 @@
 _Tp&
 map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k)
 {
-    __node_base_pointer __parent;
-    __node_base_pointer& __child = __find_equal_key(__parent, __k);
+    __parent_pointer __parent;
+    __node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
 #ifndef _LIBCPP_NO_EXCEPTIONS
     if (__child == nullptr)
         throw out_of_range("map::at:  key not found");
@@ -1453,8 +1398,8 @@
 const _Tp&
 map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) const
 {
-    __node_base_pointer __parent;
-    __node_base_pointer __child = __find_equal_key(__parent, __k);
+    __parent_pointer __parent;
+    __node_base_pointer __child = __tree_.__find_equal(__parent, __k);
 #ifndef _LIBCPP_NO_EXCEPTIONS
     if (__child == nullptr)
         throw out_of_range("map::at:  key not found");
@@ -1621,7 +1566,7 @@
 
 #if _LIBCPP_STD_VER > 11
     template <class _InputIterator>
-    _LIBCPP_INLINE_VISIBILITY 
+    _LIBCPP_INLINE_VISIBILITY
     multimap(_InputIterator __f, _InputIterator __l, const allocator_type& __a)
         : multimap(__f, __l, key_compare(), __a) {}
 #endif
@@ -1688,7 +1633,7 @@
         }
 
 #if _LIBCPP_STD_VER > 11
-    _LIBCPP_INLINE_VISIBILITY 
+    _LIBCPP_INLINE_VISIBILITY
     multimap(initializer_list<value_type> __il, const allocator_type& __a)
         : multimap(__il, key_compare(), __a) {}
 #endif