Most accurate way to do a combined multiply-and-divide operation in 64-bit?
Since this is tagged Visual C++ I'll give a solution that abuses MSVC-specific intrinsics.
This example is fairly complicated. It's a highly simplified version of the same algorithm that is used by GMP and java.math.BigInteger
for large division.
Although I have a simpler algorithm in mind, it's probably about 30x slower.
This solution has the following constraints/behavior:
- It requires x64. It will not compile on x86.
- The divisor must not be zero.
- The quotient saturates if it overflows 64-bits.
Note that this is for the unsigned integer case. It's trivial to build a wrapper around this to make it work for signed cases as well. This example should also produce correctly truncated results.
This code is not fully tested. However, it has passed all the test cases that I've thrown at it.
(Even cases that I've intentionally constructed to try to break the algorithm.)
#include <intrin.h>
// Computes floor(a * b / c) for unsigned 64-bit operands using MSVC x64
// intrinsics (_BitScanReverse64, _umul128, __shiftleft128).
//
// The 128-bit product a*b is divided by c with two rounds of schoolbook long
// division in base 2^32. First, c is shifted left until its top bit is set,
// and the product is shifted by the same amount. Then each 32-bit quotient
// digit is estimated from the divisor's top 32 bits. Finally, the estimate is
// stepped down until the partial product no longer exceeds the dividend.
//
// Returns 0xffffffffffffffff (after printing "Overflow") when the quotient does
// not fit in 64 bits. Also returns 0xffffffffffffffff, without printing, when
// c == 0.
uint64_t muldiv2(uint64_t a, uint64_t b, uint64_t c){
    // For c == 0, _BitScanReverse64 leaves `shift` undefined and `div` below
    // would be zero; saturate instead of invoking undefined behavior.
    if (c == 0)
        return 0xffffffffffffffff;

    // Normalize divisor so that its most significant bit is set.
    unsigned long shift;
    _BitScanReverse64(&shift,c);
    shift = 63 - shift;
    c <<= shift;

    // Multiply: (b:a) = a * b as a 128-bit value (b = high word).
    a = _umul128(a,b,&b);
    // Shifting the product left by `shift` must not drop any high bits.
    if (((b << shift) >> shift) != b){
        cout << "Overflow" << endl;
        return 0xffffffffffffffff;
    }
    b = __shiftleft128(a,b,shift);
    a <<= shift;

    uint32_t div;
    uint32_t q0,q1;
    uint64_t t0,t1;

    // 1st reduction: upper 32-bit quotient digit q1.
    div = (uint32_t)(c >> 32);
    t0 = b / div;
    if (t0 > 0xffffffff)
        t0 = 0xffffffff;
    q1 = (uint32_t)t0;
    // The estimate is never too small; step it down until
    // c * (q1 << 32) <= (b:a).
    while (1){
        t0 = _umul128(c,(uint64_t)q1 << 32,&t1);
        if (t1 < b || (t1 == b && t0 <= a))
            break;
        q1--;
    }
    // (b:a) -= (t1:t0), propagating the borrow out of the low word.
    b -= t1;
    if (t0 > a) b--;
    a -= t0;
    // A remainder of 2^96 or more means q1 would have needed more than
    // 32 bits, i.e. the quotient does not fit in 64 bits.
    if (b > 0xffffffff){
        cout << "Overflow" << endl;
        return 0xffffffffffffffff;
    }

    // 2nd reduction: lower 32-bit quotient digit q0, same estimate-and-correct
    // scheme on the remaining dividend.
    t0 = ((b << 32) | (a >> 32)) / div;
    if (t0 > 0xffffffff)
        t0 = 0xffffffff;
    q0 = (uint32_t)t0;
    while (1){
        t0 = _umul128(c,q0,&t1);
        if (t1 < b || (t1 == b && t0 <= a))
            break;
        q0--;
    }
    // The normalized remainder is (a - t0); shifting it right by `shift`
    // would give (a * b) mod c.
    return ((uint64_t)q1 << 32) | q0;
}
Note that if you don't need a perfectly truncated result, you can remove the last loop completely. If you do this, the answer will be no more than 2 larger than the correct quotient.
Test Cases:
// Example calls (place inside a function such as main). Every literal carries
// a ULL suffix: unsuffixed decimal literals above 9223372036854775807 do not
// fit in any signed type allowed for them, so compilers warn or reject them.
cout << muldiv2(4984198405165151231ULL,6132198419878046132ULL,9156498145135109843ULL) << endl;
cout << muldiv2(11540173641653250113ULL, 10150593219136339683ULL, 13592284235543989460ULL) << endl;
cout << muldiv2(449033535071450778ULL, 3155170653582908051ULL, 4945421831474875872ULL) << endl;
cout << muldiv2(303601908757ULL, 829267376026ULL, 659820219978ULL) << endl;
cout << muldiv2(449033535071450778ULL, 829267376026ULL, 659820219978ULL) << endl;
cout << muldiv2(1234568ULL, 829267376026ULL, 1ULL) << endl;
cout << muldiv2(6991754535226557229ULL, 7798003721120799096ULL, 4923601287520449332ULL) << endl;
cout << muldiv2(9223372036854775808ULL, 2147483648ULL, 18446744073709551615ULL) << endl;
cout << muldiv2(9223372032559808512ULL, 9223372036854775807ULL, 9223372036854775807ULL) << endl;
cout << muldiv2(9223372032559808512ULL, 9223372036854775807ULL, 12ULL) << endl;
cout << muldiv2(18446744073709551615ULL, 18446744073709551615ULL, 9223372036854775808ULL) << endl;
Output:
3337967539561099935
8618095846487663363
286482625873293138
381569328444
564348969767547451
1023786965885666768
11073546515850664288
1073741824
9223372032559808512
Overflow
18446744073709551615
Overflow
18446744073709551615
You only need 64-bit integers. There are some redundant operations, but they make it possible to use 10 as the base and step through the algorithm in a debugger.
/* Digit base of the long division in multdiv_reduced. Setting this to 10 lets
   you single-step the algorithm with decimal digits in a debugger; c must then
   not exceed maxdiv (99), otherwise the normalization factor becomes 0. */
#define MULTDIV_BASE ((uint64_t)1 << 32)
static const uint64_t base = MULTDIV_BASE;
/* Largest two-digit number, (base-1)*base + (base-1); 2^64 - 1 for base 2^32.
   Written with the macro so the initializer is a constant expression in C
   (reading a const variable is not), as well as in C++. */
static const uint64_t maxdiv = (MULTDIV_BASE - 1) * MULTDIV_BASE + (MULTDIV_BASE - 1);

/* Returns floor(a * b / c) for 0 <= a < c and 0 <= b < c (so c != 0).
   The result is < c, so it always fits in 64 bits. This is the two-digit case
   of Knuth's long division (TAOCP vol. 2, 4.3.1, Algorithm D) in base `base`,
   arranged so that every intermediate product fits in 64 bits. */
static uint64_t multdiv_reduced(uint64_t a, uint64_t b, uint64_t c)
{
    if (a == 0 || b == 0)
        return 0;
    /* a, b < c < base, so a*b < base^2 and plain arithmetic is exact. */
    if (c < base)
        return a * b / c;

    /* Now 0 < a < c, 0 < b < c, c >= base.
       Normalize: scaling c and a by the same factor leaves a*b/c unchanged and
       makes the divisor's top digit >= base/2, which bounds the correction
       loops below. Since a < c, a*norm < c*norm <= maxdiv: no overflow. */
    uint64_t norm = maxdiv / c;
    c *= norm;
    a *= norm;

    /* Split into 2 digits. */
    uint64_t ah = a / base, al = a % base;
    uint64_t bh = b / base, bl = b % base;
    uint64_t ch = c / base, cl = c % base;

    /* Compute the product; p2 holds 2 digits, p1 and p0 one. */
    uint64_t p0 = al * bl;
    uint64_t p1 = p0 / base + al * bh;
    p0 %= base;
    uint64_t p2 = p1 / base + ah * bh;
    p1 = (p1 % base) + ah * bl;
    p2 += p1 / base;
    p1 %= base;

    /* a < c and b < c imply p2 < c, so the quotient has only the two digits
       q1:q0 (no third, overflow digit). */

    /* Upper digit: estimate from the divisor's top digit, then adjust. Because
       c has exactly two digits, this test is exact. The loop can be unrolled:
       it runs at most twice for even bases, three times for odd ones, thanks to
       the normalization above. q1 >= base is checked first so q1*cl cannot
       overflow; rhat < base keeps rhat*base + p1 in range. */
    uint64_t q1 = p2 / ch;
    uint64_t rhat = p2 % ch;
    while (q1 >= base || (rhat < base && q1 * cl > rhat * base + p1)) {
        q1--;
        rhat += ch;
    }

    /* Subtract q1*c digit by digit. Intermediate borrows wrap modulo 2^64 and
       cancel out; the partial remainder (< c) ends up as 2 digits in p1. */
    p1 = ((p2 % base) * base + p1) - q1 * cl;
    p2 = (p2 / base * base + p1 / base) - q1 * ch;
    p1 = p1 % base + (p2 % base) * base;

    /* Lower digit: same estimate-and-adjust scheme on p1:p0. */
    uint64_t q0 = p1 / ch;
    rhat = p1 % ch;
    while (q0 >= base || (rhat < base && q0 * cl > rhat * base + p0)) {
        q0--;
        rhat += ch;
    }
    /* The final subtraction is skipped: it is only needed for the remainder,
       which would then have to be divided by norm. */
    return q1 * base + q0;
}

/* Returns floor(a * b / c) using only 64-bit integer arithmetic; the full
   128-bit product is never formed.
   Saturates to 0xFFFFFFFFFFFFFFFF when the exact quotient does not fit in
   64 bits, and when c == 0. */
uint64_t multdiv(uint64_t a, uint64_t b, uint64_t c)
{
    const uint64_t umax = ~(uint64_t)0;

    if (c == 0)
        return umax;

    uint64_t qa = a / c, ra = a % c;
    uint64_t qb = b / c, rb = b % c;

    /* a*b/c = qa*b + ra*qb + floor(ra*rb/c). All terms are non-negative, so
       the quotient exceeds 64 bits exactly when one of these steps overflows. */
    if (qa != 0 && b > umax / qa)
        return umax;
    uint64_t res = qa * b;

    if (qb != 0 && ra > umax / qb)
        return umax;
    uint64_t term = ra * qb;
    if (term > umax - res)
        return umax;
    res += term;

    term = multdiv_reduced(ra, rb, c);
    if (term > umax - res)
        return umax;
    return res + term;
}