Find the first element in a sorted array that is greater than the target
One way of thinking about this problem is to think about doing a binary search over a transformed version of the array, where each element x has been replaced by

$$f(x) = \begin{cases} 1 & \text{if } x > \text{target} \\ 0 & \text{otherwise} \end{cases}$$

Because the array is sorted, the transformed array is a run of 0s followed by a run of 1s. Now, the goal is to find the very first place where this function takes on the value 1. We can do that using a binary search as follows:
int low = 0, high = numElems; // numElems is the size of the array, i.e. arr.size()
while (low != high) {
    int mid = low + (high - low) / 2; // same value as (low + high) / 2, but cannot overflow
    if (arr[mid] <= target) {
        /* This index, and everything below it, cannot be the first element
         * greater than target, because this element is no greater than target.
         */
        low = mid + 1;
    }
    else {
        /* This element is greater than target, so anything after it can't
         * be the first element that's greater than target.
         */
        high = mid;
    }
}
/* Now low and high both point to the element in question
 * (or both equal numElems if no element is greater than target). */
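The loop above can also be packaged as a self-contained function that follows the transformed-array view directly. In this sketch, `firstTrue` searches for the first index in [0, n) where a monotone predicate f holds, and `firstGreater` plugs in f(i) = arr[i] > target. Both names are hypothetical, and the sketch assumes the array is sorted in non-decreasing order so that the predicate is monotone (all false entries before all true ones).

```cpp
#include <cstddef>
#include <vector>

// Returns the first index i in [0, n) for which f(i) is true, or n if there is none.
// Assumes f is monotone: once f(i) is true, f(j) is true for every j > i.
template <typename Pred>
std::size_t firstTrue(std::size_t n, Pred f) {
    std::size_t low = 0, high = n;
    while (low != high) {
        // low + (high - low) / 2 cannot overflow, unlike (low + high) / 2.
        std::size_t mid = low + (high - low) / 2;
        if (!f(mid)) {
            low = mid + 1;  // f is false at mid and at every index before it
        } else {
            high = mid;     // mid is a candidate; nothing after it can be the first
        }
    }
    return low;
}

// Index of the first element of a sorted vector that is greater than target,
// or arr.size() if every element is <= target.
std::size_t firstGreater(const std::vector<int>& arr, int target) {
    return firstTrue(arr.size(), [&](std::size_t i) { return arr[i] > target; });
}
```

Separating the predicate from the search keeps the loop identical for other monotone conditions; only the lambda changes.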
To see that this algorithm is correct, consider each comparison being made. If we find an element that's no greater than the target, then it and everything below it can't possibly match, so there's no need to search that region; we can continue with the right half. If we find an element that is greater than the target, then everything after it must also be greater (the array is sorted), so none of those can be the first element that's greater, and we don't need to search them. The middle element is thus the last possible place the answer could be, which is why high is set to mid rather than mid - 1.
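One way to gain extra confidence in this argument is to compare the loop against a plain linear scan on every small sorted array. The sketch below enumerates all non-decreasing arrays of length up to `maxLen` with values in [0, maxVal), tries every target from -1 to maxVal, and counts disagreements. The helper names (`binaryFirstGreater`, `linearFirstGreater`, `countMismatches`, `checkAgainstLinearScan`) are hypothetical; the binary search is the loop above with the overflow-safe midpoint.

```cpp
#include <cstddef>
#include <vector>

// The loop from the answer, returning low (== arr.size() if nothing is greater).
std::size_t binaryFirstGreater(const std::vector<int>& arr, int target) {
    std::size_t low = 0, high = arr.size();
    while (low != high) {
        std::size_t mid = low + (high - low) / 2;
        if (arr[mid] <= target) low = mid + 1;
        else high = mid;
    }
    return low;
}

// Reference answer: scan left to right until an element exceeds target.
std::size_t linearFirstGreater(const std::vector<int>& arr, int target) {
    std::size_t i = 0;
    while (i < arr.size() && arr[i] <= target) ++i;
    return i;
}

// Checks the current array against every target, then recursively extends it
// with values >= minVal (keeping it sorted). Returns the number of mismatches.
int countMismatches(std::vector<int>& arr, std::size_t maxLen, int minVal, int maxVal) {
    int mismatches = 0;
    for (int target = -1; target <= maxVal; ++target)
        if (binaryFirstGreater(arr, target) != linearFirstGreater(arr, target))
            ++mismatches;
    if (arr.size() == maxLen) return mismatches;
    for (int v = minVal; v < maxVal; ++v) {
        arr.push_back(v);
        mismatches += countMismatches(arr, maxLen, v, maxVal);
        arr.pop_back();
    }
    return mismatches;
}

// Tests every sorted array of length 0..maxLen with values in [0, maxVal).
int checkAgainstLinearScan(std::size_t maxLen, int maxVal) {
    std::vector<int> arr;
    return countMismatches(arr, maxLen, 0, maxVal);
}
```

A result of 0 means every case agreed; keeping `maxLen` and `maxVal` small keeps the enumeration fast.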
Note that on each iteration we drop at least half of the remaining elements from consideration. Write $m = \lfloor (low + high) / 2 \rfloor$. If the top branch executes, then the elements in the range $[low, m]$ are all discarded, causing us to lose

$$m - low + 1 \ge \frac{low + high}{2} - low = \frac{high - low}{2}$$

elements. If the bottom branch executes, then the elements in the range $[m + 1, high]$ are all discarded. This loses us

$$high - m \ge high - \frac{low + high}{2} = \frac{high - low}{2}$$

elements.
Consequently, we'll end up finding the first element greater than the target in O(lg n) iterations of this process.
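The halving argument can also be checked by counting loop iterations. Each step shrinks the size s = high - low of the unresolved range to at most floor(s / 2), so for n >= 1 elements the loop body runs at most floor(log2 n) + 1 times (the number of binary digits of n), and not at all when n = 0. The sketch below is an instrumented copy of the loop; `countIterations` and `bitLength` are hypothetical helpers.

```cpp
#include <cstddef>
#include <vector>

// Runs the search from the answer and returns how many times the loop body executed.
std::size_t countIterations(const std::vector<int>& arr, int target) {
    std::size_t low = 0, high = arr.size(), iterations = 0;
    while (low != high) {
        ++iterations;
        std::size_t mid = low + (high - low) / 2;
        if (arr[mid] <= target) low = mid + 1;
        else high = mid;
    }
    return iterations;
}

// Number of binary digits of n (0 for n == 0), i.e. floor(log2 n) + 1 for n >= 1.
std::size_t bitLength(std::size_t n) {
    std::size_t bits = 0;
    while (n != 0) {
        n >>= 1;
        ++bits;
    }
    return bits;
}
```

For example, bitLength(1000000) is 20, so a sorted array of a million elements needs at most 20 iterations.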
Here's a trace of the algorithm running on the array 0 0 1 1 1 1 with target 0.

Initially, we have

0 0 1 1 1 1
L = 0, H = 6

So we compute mid = (0 + 6) / 2 = 3, so we inspect the element at position 3, which has value 1. Since 1 > 0, we set high = mid = 3. We now have

0 0 1
L = 0, H = 3
We compute mid = (0 + 3) / 2 = 1, so we inspect element 1. Since this has value 0 <= 0, we set low = mid + 1 = 2. We're now left with L = 2 and H = 3:

0 0 1
L = 2, H = 3

Now, we compute mid = (2 + 3) / 2 = 2. The element at index 2 is 1, and since 1 > 0, we set H = mid = 2. Now L == H, so we stop, and indeed we're looking at index 2, the first element greater than 0.
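The trace can be reproduced mechanically by recording the bounds on every iteration. In this sketch, `Step` and `traceSearch` are hypothetical names; the loop is the one from the answer with the overflow-safe midpoint, which gives the same mid values as (low + high) / 2 here.

```cpp
#include <cstddef>
#include <vector>

// One loop iteration: the bounds on entry, the probed index, and its value.
struct Step {
    std::size_t low, high, mid;
    int value;
};

// Runs the search, appending one Step per iteration; returns the final low.
std::size_t traceSearch(const std::vector<int>& arr, int target, std::vector<Step>& steps) {
    std::size_t low = 0, high = arr.size();
    while (low != high) {
        std::size_t mid = low + (high - low) / 2;
        steps.push_back({low, high, mid, arr[mid]});
        if (arr[mid] <= target) low = mid + 1;
        else high = mid;
    }
    return low;
}
```

For arr = {0, 0, 1, 1, 1, 1} and target 0, this records three steps with (low, high, mid) equal to (0, 6, 3), (0, 3, 1) and (2, 3, 2), and returns 2, matching the walkthrough.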
You can use std::upper_bound if the array is sorted (assuming n is the size of array a[]):
#include <algorithm> // std::upper_bound
#include <iostream>

int* p = std::upper_bound( a, a + n, x );
if( p == a + n )
    std::cout << "No element greater";
else
    std::cout << "The first element greater is " << *p
              << " at position " << p - a;
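The same call works on a std::vector through its iterators, and std::distance turns the returned iterator into an index. The sketch below uses a hypothetical helper, `firstGreaterIndex`, that returns the same index as the hand-written loop (a.size() when no element is greater). std::upper_bound itself makes at most log2(n) + O(1) comparisons.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Index of the first element greater than x in a sorted vector, or a.size() if none.
// Assumes a is sorted in non-decreasing order, as std::upper_bound requires.
std::size_t firstGreaterIndex(const std::vector<int>& a, int x) {
    auto it = std::upper_bound(a.begin(), a.end(), x);
    return static_cast<std::size_t>(std::distance(a.begin(), it));
}
```

Returning an index rather than an iterator makes the "not found" case (a.size()) line up with the numElems convention used in the loop version.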