Fast, branchless unsigned int absolute difference
If you are targeting a system with SSE instructions, you can use them for a nice performance boost. I tested this against the other posted methods, and it seems to be the fastest approach.
Example results for diffing a large number of values:
diff0: 188.020679 ms // branching
diff1: 118.934970 ms // max min
diff2: 97.087710 ms // branchless mul add
diff3: 54.495269 ms // branchless signed
diff4: 31.159628 ms // sse
diff5: 30.855885 ms // sse v2
My full test code is below. I used SSE2 instructions, which are widely available on x86 CPUs nowadays. I accessed them through SSE intrinsics, which should be quite portable across compilers (MSVC, GCC, Clang, Intel, etc.).
Notes:
- Effectively this computes max(a, b) - min(a, b), but it processes 16 values at once with each instruction.
- Unrolling it in `diff5` seems to have little effect, but it could possibly be tweaked.
- The fallback for the last 15 or fewer values currently uses the signed trick in a loop, but it could possibly be sped up further with unrolling and/or SSE.
- The functions themselves are quite simple so they should be easily portable to anything with SSE intrinsics or asm.
- I used Windows-specific timing functions because `std::chrono::high_resolution_clock` has low precision in the MSVC implementation. Sorry for that, and for the dirty mix of C/C++ test code.
- After timing, each result is checked against the reference branching implementation, so the results should be correct.
Please leave a comment if you have any questions/suggestions regarding the code or this approach in general.
#include <cstdlib>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <random>
#include <algorithm>
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <Windows.h>
#include <emmintrin.h> // sse2
// branching
// Reference implementation: for each index, subtract the smaller byte from
// the larger one so the result never wraps. Writes n results to res.
void diff0(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
           std::size_t n)
{
    for (std::size_t k = 0; k != n; ++k) {
        const std::uint8_t x = a[k];
        const std::uint8_t y = b[k];
        if (x > y) {
            res[k] = static_cast<std::uint8_t>(x - y);
        } else {
            res[k] = static_cast<std::uint8_t>(y - x);
        }
    }
}
// max min
// |a - b| computed as max(a, b) - min(a, b); the subtraction cannot underflow
// because the larger operand is always on the left.
void diff1(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
           std::size_t n)
{
    for (std::size_t k = 0; k < n; ++k) {
        const auto bounds = std::minmax(a[k], b[k]);  // (smaller, larger)
        res[k] = static_cast<std::uint8_t>(bounds.second - bounds.first);
    }
}
// branchless mul add
// Each comparison evaluates to 0 or 1 and selects at most one of the two
// (int-promoted) differences; when the inputs are equal both terms are 0.
void diff2(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
           std::size_t n)
{
    for (std::size_t k = 0; k < n; ++k) {
        const int x = a[k];
        const int y = b[k];
        const int greater = x > y;
        const int less = x < y;
        res[k] = static_cast<std::uint8_t>(greater * (x - y) + less * (y - x));
    }
}
// branchless signed
// Widen both bytes to int so the difference (in [-255, 255]) cannot wrap,
// then take its absolute value with the two's-complement identity
// |d| == (d + m) ^ m, where m is all ones when d < 0 and zero otherwise.
// The mask is derived from a comparison instead of `diff >> 15` because
// right-shifting a negative signed value is implementation-defined before
// C++20; compilers still emit branchless code for this form.
void diff3(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
           std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        const int diff = static_cast<int>(a[i]) - static_cast<int>(b[i]);
        const int mask = -static_cast<int>(diff < 0);  // 0 or -1 (all ones)
        res[i] = static_cast<std::uint8_t>((diff + mask) ^ mask);
    }
}
// sse
// Per 16-byte vector, |a - b| == max(a, b) - min(a, b) computed with the
// unsigned byte max/min instructions, so no lane can wrap. The remaining
// n % 16 values are handled by a scalar loop.
// Unaligned loads/stores (_mm_loadu_si128/_mm_storeu_si128) are used so that
// callers may pass pointers with any alignment; the aligned forms fault on
// misaligned addresses. On recent x86 cores the unaligned forms are
// typically as fast as the aligned ones when the data happens to be aligned.
void diff4(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
           std::size_t n)
{
    const std::size_t vec_end = n - n % 16;  // largest multiple of 16 <= n
    std::size_t i = 0;
    for (; i < vec_end; i += 16) {
        const __m128i va = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + i));
        const __m128i vb = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b + i));
        const __m128i d = _mm_sub_epi8(_mm_max_epu8(va, vb), _mm_min_epu8(va, vb));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i), d);
    }
    // Scalar fallback for the remaining < 16 values: branchless sign-mask
    // absolute value, with the mask built from a comparison (portable,
    // unlike right-shifting a negative value before C++20).
    for (; i < n; i++) {
        const int diff = static_cast<int>(a[i]) - static_cast<int>(b[i]);
        const int mask = -static_cast<int>(diff < 0);
        res[i] = static_cast<std::uint8_t>((diff + mask) ^ mask);
    }
}
// sse v2
// Same max(a, b) - min(a, b) scheme as the single-vector SSE loop, unrolled
// to process two 16-byte vectors (32 values) per iteration. It is followed by
// at most one single-vector step and a scalar loop for the last < 16 values.
// Unaligned loads/stores are used so that callers may pass pointers with any
// alignment (the aligned forms fault on misaligned addresses).
void diff5(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
           std::size_t n)
{
    const std::size_t VEC = 16;    // bytes per __m128i
    const std::size_t UNROLL = 2;  // vectors per unrolled iteration
    std::size_t i = 0;
    // `n - i` cannot underflow because i <= n holds throughout.
    for (; n - i >= VEC * UNROLL; i += VEC * UNROLL) {
        // Load all operands before storing so in-place use (res == a or b) works.
        const __m128i a0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + i));
        const __m128i b0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b + i));
        const __m128i a1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + i + VEC));
        const __m128i b1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b + i + VEC));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i),
                         _mm_sub_epi8(_mm_max_epu8(a0, b0), _mm_min_epu8(a0, b0)));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i + VEC),
                         _mm_sub_epi8(_mm_max_epu8(a1, b1), _mm_min_epu8(a1, b1)));
    }
    for (; n - i >= VEC; i += VEC) {
        const __m128i va = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + i));
        const __m128i vb = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b + i));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i),
                         _mm_sub_epi8(_mm_max_epu8(va, vb), _mm_min_epu8(va, vb)));
    }
    // Scalar fallback for the remaining < 16 values: branchless sign-mask
    // absolute value, with the mask built from a comparison (portable,
    // unlike right-shifting a negative value before C++20).
    for (; i < n; i++) {
        const int diff = static_cast<int>(a[i]) - static_cast<int>(b[i]);
        const int mask = -static_cast<int>(diff < 0);
        res[i] = static_cast<std::uint8_t>((diff + mask) ^ mask);
    }
}
int main() {
const std::size_t ALIGN = 16; // sse requires 16 bit align
const std::size_t N = 10 * 1024 * 1024 * 3;
auto a = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
auto b = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
{ // fill with random values
std::mt19937 engine(std::random_device{}());
std::uniform_int<std::uint8_t> distribution(0, 255);
for (std::size_t i = 0; i < N; i++) {
a[i] = distribution(engine);
b[i] = distribution(engine);
}
}
auto res0 = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff0 results
auto resX = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff1+ results
LARGE_INTEGER f, t0, t1;
QueryPerformanceFrequency(&f);
QueryPerformanceCounter(&t0);
diff0(a, b, res0, N);
QueryPerformanceCounter(&t1);
printf("diff0: %.6f ms\n",
static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);
#define TEST(diffX)\
QueryPerformanceCounter(&t0);\
diffX(a, b, resX, N);\
QueryPerformanceCounter(&t1);\
printf("%s: %.6f ms\n", #diffX,\
static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);\
for (std::size_t i = 0; i < N; i++) {\
if (resX[i] != res0[i]) {\
printf("error: %s(%03u, %03u) == %03u != %03u\n", #diffX,\
a[i], b[i], resX[i], res0[i]);\
break;\
}\
}
TEST(diff1);
TEST(diff2);
TEST(diff3);
TEST(diff4);
TEST(diff5);
_mm_free(a);
_mm_free(b);
_mm_free(res0);
_mm_free(resX);
getc(stdin);
return 0;
}
Edit: Changing my answer, I had optimizations misconfigured for this.
I set up a quick test bed for this in C, and I'm finding that
a - b + (a < b) * ((b - a) << 1);
is a hair better, at least in my setup. The advantage of my approach is that it eliminates a comparison: your version implicitly handles `a - b == 0` as if it were a separate case, which is not necessary.
Timings in my test:
- Your implementation: 371ms
- This implementation: 324ms
- Speedup: 14%
I tried an approach with a branchless absolute value, and the results were better. Note that whether the compiler treats the inputs or outputs as signed is irrelevant. The result wraps around for large unsigned values, but since it only has to work on small values (as stated in the question), it should be sufficient.
s32 diff = a - b;
u32 mask = diff >> 31;
return (diff + mask) ^ mask;
- Your Implementation: 371ms
- This implementation: 241ms
- Speedup: 53%!
Well, I tried to benchmark a bit. I use Criterion for the benchmarks, because it does proper significance tests. I also use QuickCheck here to ensure that all methods return the same results.
I compiled with GHC 7.6.3 (so unfortunately I couldn't include your primops function) and with `-O3`:
ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff
Primarily, we can see the difference between a naive implementation and a bit of fiddling:
absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 a b = max a b - min a b
absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
mask = unsafeShiftR v 63
Output:
benchmarking absdiff_Word8/1
mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
....
benchmarking absdiff_Word8/2
mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
...
I use the integer absolute value trick from "Bit Twiddling Hacks". Unfortunately we need casts. I don't think it is possible to solve the problem well in the domain of `Word8` alone, but it seems sensible to use the native integer type anyway (though there's definitely no need to create a heap object).
It doesn't really look like a large difference, but my test setup is also not perfect: I am mapping the function over a large list of random values to rule out branch prediction making the branching version seem more efficient than it is. This causes thunks to build up in memory, which could influence the timings a lot. When we subtract the constant overhead for maintaining the list, we could well see a lot more than the 20% speedup.
The generated assembly is actually pretty good (this is an inlined version of the function):
.Lc4BB:
leaq 7(%rbx),%rax
movq 8(%rbp),%rbx
subq (%rax),%rbx
movq %rbx,%rax
sarq $63,%rax
movq $base_GHCziInt_I64zh_con_info,-8(%r12)
addq %rax,%rbx
xorq %rax,%rbx
movq %rbx,0(%r12)
leaq -7(%r12),%rbx
movq $s4z0_info,8(%rbp)
1 subtraction, 1 addition, 1 right shift, 1 xor and no branch, as expected. Using the LLVM backend doesn't improve the runtime noticeably.
Hope this is useful if you want to try out more stuff.
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Benchmarks several implementations of the absolute difference of two
-- small integers (max/min, the branchless "Bit Twiddling Hacks" absolute
-- value trick, and a plain branch). Before benchmarking, it QuickChecks that
-- the variants agree.
module Main where

import Data.Word
import Data.Int
import Data.Bits
import Control.Arrow ((***))
import Control.DeepSeq (force)
import Control.Exception (evaluate)
import Control.Monad
import System.Random
import Unsafe.Coerce
import Test.QuickCheck hiding ((.&.))
import Criterion.Main

-- | Naive version: larger minus smaller (cannot underflow).
absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 !a !b = max a b - min a b

absdiff1_int16 :: Int16 -> Int16 -> Int16
absdiff1_int16 a b = max a b - min a b

absdiff1_int :: Int -> Int -> Int
absdiff1_int a b = max a b - min a b

-- | Branchless |a - b| for Int16. The arithmetic shift smears the sign bit
-- into an all-ones (negative) or all-zeros mask, and (v + mask) `xor` mask
-- negates v only when the mask is all ones. This is only correct while a - b
-- does not overflow Int16, which holds for the Word8-range inputs used below.
absdiff2_int16 :: Int16 -> Int16 -> Int16
absdiff2_int16 a b = xor (v + mask) mask
  where
    v = a - b
    mask = unsafeShiftR v 15

-- | Branchless |a - b| for Word8, computed in Int64 so that a - b cannot wrap.
-- Uses fromIntegral rather than unsafeCoerce: coercing a boxed Word8 to a
-- boxed Int64 only worked while both happened to share a representation,
-- which is no longer true on newer GHCs. GHC has rewrite rules that usually
-- reduce these conversions to plain extends/narrows; check Core if it matters.
absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 !a !b = fromIntegral (xor (v + mask) mask)
  where
    !v = (fromIntegral a :: Int64) - (fromIntegral b :: Int64)
    !mask = unsafeShiftR v 63

-- | Branching version.
absdiff3_w8 :: Word8 -> Word8 -> Word8
absdiff3_w8 a b = if a > b then a - b else b - a

-- Primop version. It needs MagicHash and GHC.Exts, plus GHC >= 7.8, where the
-- comparison primops return Int#, so it is unavailable with GHC 7.6.3.
{-absdiff4_int :: Int -> Int -> Int-}
{-absdiff4_int (I# a) (I# b) =-}
{-    I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}

-- | Convert between Enum types via Int (used here to widen Word8 to Int16).
e2e :: (Enum a, Enum b) => a -> b
e2e = toEnum . fromEnum

prop_same1 :: Word8 -> Word8 -> Bool
prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y

prop_same2 :: Word8 -> Word8 -> Bool
prop_same2 x y = absdiff1_int16 x' y' == absdiff2_int16 x' y'
  where
    x' = e2e x
    y' = e2e y

prop_same3 :: Word8 -> Word8 -> Bool
prop_same3 x y = absdiff1_w8 x y == absdiff3_w8 x y

check :: IO ()
check = quickCheck prop_same1
     >> quickCheck prop_same2
     >> quickCheck prop_same3

-- | A random pair from two independent draws. This replaces an orphan
-- `Random (x, y)` instance, which clashes with the tuple instance that
-- random >= 1.2 already provides.
randomPair :: (Random a, Random b) => IO (a, b)
randomPair = liftM2 (,) randomIO randomIO

main :: IO ()
main = do
  check
  !pairs_w8 <- fmap force $ replicateM 10000 (randomPair :: IO (Word8, Word8))
  let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
  defaultMain
    [ bgroup "absdiff_Word8"
        [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
        , bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
        , bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
        ]
    , bgroup "absdiff_Int16"
        [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
        , bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
        ]
    -- An Int group could be added once absdiff4_int is enabled, e.g.
    -- , bgroup "absdiff_Int" [ bench "1" $ whnf (absdiff1_int 13) 14
    --                        , bench "4" $ whnf (absdiff4_int 13) 14 ]
    ]