Triangles in a tetrahedron
Rust, \$A(1), \dotsc, A(1375)\$ in 10 minutes
Unofficial score on a Ryzen 7 1800X (8 cores/16 threads). Build with `cargo build --release`
and run with `time target/release/tetrahedron n`
to compute \$A(1), \dotsc, A(n)\$.
This runs in \$O(n^4)\$ time. (So to estimate a good value of \$n\$ for your CPU, first time it for some smaller \$n\$, then multiply that \$n\$ by a factor of \$\left(\frac{600\,\mathrm{s}}{t}\right)^{1/4}\$.)
How it works
Any triangle that fits inside a tetrahedron of minimal side \$k \le n\$ can be translated inside a tetrahedron of side \$n\$ in exactly \$\binom{n - k + 3}{3}\$ ways. This means we only need to find it in one position, leaving six free parameters (the coordinates of the other two vertices relative to the first). If the triangle is to be equilateral, two of these parameters can be computed from the other four (up to a sign choice), so we only need to loop over an \$O(n^4)\$ space.
src/main.rs
use rayon::prelude::*;
/// Counts equilateral lattice triangles for one fixed value of `a0`, bucketed by
/// `t / 2`, where `t` is computed below. (`t / 2` is presumably the size of the
/// smallest enclosing tetrahedron; `main` weights bucket `k` by a binomial count.)
///
/// Each triangle is represented with one vertex at the origin and the other two at
/// `a = (a0, a1, a2)` and `b = (b0, b1, b2)`. Only vectors with an even coordinate
/// sum are visited:
/// - the `a2` range starts at a value with the parity of `a0 + a1` and steps by 2;
/// - `b` is rejected when `b0 + b1 + b2` is odd.
///
/// With `a0 >= 0`, `a` is lexicographically positive (`a1 >= 1` when `a0 == 0`),
/// and `b > a` is enforced. So the origin is the lexicographically smallest vertex
/// and `b` the largest — presumably so that each translated triangle is counted once.
///
/// Returns a vector of length `n`. Entry `t / 2` is incremented for every accepted
/// triangle with `t < 2 * n`; triangles with a larger `t` are dropped.
///
/// NOTE(review): `main` only calls this with `0 <= a0 < n`. The `break` on `pp < 0`
/// relies on `pp` being non-increasing in `b0`, which holds for `b0 >= a0 >= 0`.
/// Confirm before calling with a negative `a0`.
fn get_counts(n: i64, a0: i64) -> Vec<i64> {
    // c[k]: number of accepted triangles with t / 2 == k.
    let mut c = vec![0; n as usize];
    let a0a0 = a0 * a0;
    // a1 starts at 1 when a0 == 0 so that `a` is lexicographically positive.
    for a1 in if a0 == 0 { 1 } else { -n + 1 }..n {
        let d = a0a0 + a1 * a1;
        let m = n - a0.abs() - a1.abs();
        // The a2 bounds depend on how much of `n` is already used by |a0| + |a1|.
        // They presumably keep `a` inside a side-n tetrahedron; not derived here.
        // Both branches start at a value congruent to a0 + a1 (mod 2), and
        // step_by(2) preserves that, so a0 + a1 + a2 is always even.
        for a2 in if m > 0 {
            -n + 2 - (m & 1)..n
        } else {
            -n - m + 2..n + m
        }
        .step_by(2)
        {
            // d = |a|^2, the squared side length the triangle must have.
            let d = d + a2 * a2;
            // r = -2 * (a1^2 + a2^2). It is zero only when a1 == a2 == 0, where the
            // divisions below are impossible.
            let r = 2 * (a0a0 - d);
            if r == 0 {
                continue;
            }
            for b0 in a0..n {
                // For fixed b0, the equilateral conditions |b|^2 = d and 2 a.b = d
                // determine (b1, b2) up to the sign of p. pp is 4x the discriminant
                // of that system, so it must be a non-negative perfect square for an
                // integer solution.
                let pp = d * (3 * d - 4 * (a0a0 + b0 * (b0 - a0)));
                // For b0 >= a0 >= 0, pp is non-increasing in b0; once negative it
                // stays negative.
                if pp < 0 {
                    break;
                }
                // Perfect squares below 2^53 are recovered exactly: pp converts to
                // f64 without loss and sqrt is correctly rounded.
                let p = (pp as f64).sqrt() as i64;
                if p * p != pp {
                    continue;
                }
                let q = 2 * a0 * b0 - d;
                // Reconstruct (b1, b2) for one sign of p, reject non-integer or
                // non-canonical solutions, and bucket the triangle.
                let mut check = |p: i64| {
                    let b1r = p * a2 + q * a1;
                    if b1r % r != 0 {
                        return;
                    }
                    let b1 = b1r / r;
                    let b2r = -p * a1 + q * a2;
                    if b2r % r != 0 {
                        return;
                    }
                    let b2 = b2r / r;
                    // Keep only b > a lexicographically, with b on the even-sum lattice.
                    if (b0, b1, b2) <= (a0, a1, a2) || b0 + b1 + b2 & 1 != 0 {
                        return;
                    }
                    // Each term is the largest value that one linear form takes over
                    // the vertices 0, a and b. The forms are (1,1,1), (-1,-1,1),
                    // (-1,1,-1) and (1,-1,-1). Every form has the parity of the
                    // coordinate sum, which is even for all three vertices, so t is even.
                    let t = 0.max(a0 + a1 + a2).max(b0 + b1 + b2)
                        + 0.max(-a0 - a1 + a2).max(-b0 - b1 + b2)
                        + 0.max(-a0 + a1 - a2).max(-b0 + b1 - b2)
                        + 0.max(a0 - a1 - a2).max(b0 - b1 - b2);
                    // Bucket t / 2 would fall outside 0..n.
                    if t >= 2 * n {
                        return;
                    }
                    c[t as usize / 2] += 1;
                };
                check(p);
                // p == 0 yields a single solution; don't count it twice.
                if p != 0 {
                    check(-p);
                }
            }
        }
    }
    c
}
/// Adds two count vectors element by element, reusing the first one's buffer.
///
/// Like `zip`, the result is only as long as the shorter input; surplus trailing
/// entries of either vector are discarded.
fn add_vec(mut c0: Vec<i64>, c1: Vec<i64>) -> Vec<i64> {
    c0.truncate(c1.len());
    for (acc, x) in c0.iter_mut().zip(c1) {
        *acc += x;
    }
    c0
}
/// Reads `n` from the first command-line argument and prints one line `i A(i)`
/// for every `i` in `1..=n`.
///
/// # Panics
///
/// Panics if the argument is missing, is not an integer, or is negative.
fn main() {
    let n: i64 = std::env::args()
        .nth(1)
        .expect("missing argument")
        .parse()
        .expect("not an integer");
    // Both `get_counts` and the reduction identity allocate `n as usize` slots.
    // A negative `n` would wrap to an enormous length and abort with an opaque
    // "capacity overflow", so reject it with a clear message instead.
    assert!(n >= 0, "n must be non-negative");

    // Every `a0` is independent: run them in parallel and sum the per-`a0`
    // bucket counts element-wise.
    let counts = (0..n)
        .into_par_iter()
        .map(|a0| get_counts(n, a0))
        .reduce(|| vec![0; n as usize], add_vec);

    // Four nested running sums. After bucket i (0-based), d0 equals
    // sum over k <= i of C(i - k + 3, 3) * counts[k], i.e. each bucket weighted by
    // the number of translations from the write-up.
    let (mut d0, mut d1, mut d2, mut d3) = (0, 0, 0, 0);
    for (i, x) in (1..).zip(counts) {
        d3 += x;
        d2 += d3;
        d1 += d2;
        d0 += d1;
        println!("{} {}", i, d0);
    }
}
Cargo.toml
# Package name; cargo also uses it for the binary, run as target/release/tetrahedron.
[package]
name = "tetrahedron"
version = "0.1.0"
authors = ["Anders Kaseorg <[email protected]>"]
edition = "2018"
[dependencies]
# rayon supplies the parallel iterator (`into_par_iter`) used in `main`.
rayon = "1.3.0"
Try it online! (Parallelism removed for TIO.)
C++, all up to 40 in ten minutes
Runs in \$O(n^9)\$ time (fortunately, the constant factor seems to be at most \$1/36\$, and the program is multi-threaded). Tested on Ubuntu 19.10 on an AMD Ryzen 5 2600 (12 threads), compiled with `clang++ -Ofast -march=native -flto -no-pie -fopenmp`
and run with `timeout 600 ./a.out`.
Code:
//#define _GLIBCXX_DEBUG
#include <algorithm>
#include <bitset>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <deque>
#include <iomanip>
#include <iostream>
#include <list>
#include <map>
#include <queue>
#include <random>
#include <set>
#include <stack>
#include <streambuf>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_set>
#include <vector>
// A lattice point (or a difference of two points) with small integer coordinates.
struct pt3
{
    short x, y, z;

    // Lexicographic order on (x, y, z).
    bool operator < (const pt3& rhs) const
    {
        return std::tie(x, y, z) < std::tie(rhs.x, rhs.y, rhs.z);
    }

    // Component-wise difference; each component is narrowed back to short.
    pt3 operator - (const pt3& rhs) const
    {
        return {short(x - rhs.x), short(y - rhs.y), short(z - rhs.z)};
    }

    // Squared Euclidean length, accumulated in int so the products cannot overflow short.
    int sqdist() const
    {
        return int(x)*int(x) + int(y)*int(y) + int(z)*int(z);
    }
};

// Counts the equilateral triangles whose vertices are lattice points of a
// tetrahedron with n points along each edge. It brute-forces all unordered
// triples of points, which is O(n^9).
//
// Points come from non-negative (i, j, k) with i + j + k < n, mapped to
// (i+j, j+k, i+k). The three unit steps map to (1,0,1), (1,1,0) and (0,1,1);
// together with the origin these form a regular tetrahedron, so the points form a
// regular tetrahedral lattice. The map is injective, with inverse
// i = (x-y+z)/2, j = (x+y-z)/2, k = (y+z-x)/2. Every point is therefore produced
// exactly once, and no set is needed to deduplicate them.
int solve(int n)
{
    std::vector<pt3> points;
    for(int i = 0; i < n; i++)
        for(int j = 0; i + j < n; j++)
            for(int k = 0; i + j + k < n; k++)
                points.push_back({short(i + j), short(j + k), short(i + k)});

    // Signed bound, so the loop conditions compare int with int.
    const int num_points = int(points.size());
    std::int64_t ans = 0;
    #pragma omp parallel for schedule(guided) reduction(+:ans)
    for(int i = 0; i < num_points; i++)
        for(int j = i + 1; j < num_points; j++)
        {
            const pt3 a = points[i], b = points[j];
            // Depends only on (i, j): compute it once instead of once per k.
            const int side = (a - b).sqdist();
            for(int k = j + 1; k < num_points; k++)
            {
                const pt3 c = points[k];
                if((a - c).sqdist() == side && (b - c).sqdist() == side)
                    ans++;
            }
        }
    // Narrowed to int to keep the original signature. The values this brute force
    // reaches (6244992 at n = 40) are far below INT_MAX.
    return static_cast<int>(ans);
}
// Prints A(1), A(2), ... without end; meant to be stopped externally,
// e.g. with `timeout 600 ./a.out`.
int main()
{
    for(int i = 1;; i++)
    {
        int ans = solve(i);
        printf("n=%d: %d\n", i, ans);
        // When stdout is redirected to a file or pipe it is fully buffered. A process
        // killed by `timeout` (SIGTERM) never flushes that buffer, so push every
        // result out immediately.
        fflush(stdout);
    }
}
Output:
n=1: 0
n=2: 4
n=3: 24
n=4: 84
n=5: 224
n=6: 516
n=7: 1068
n=8: 2016
n=9: 3528
n=10: 5832
n=11: 9256
n=12: 14208
n=13: 21180
n=14: 30728
n=15: 43488
n=16: 60192
n=17: 81660
n=18: 108828
n=19: 142764
n=20: 184708
n=21: 236088
n=22: 298476
n=23: 373652
n=24: 463524
n=25: 570228
n=26: 696012
n=27: 843312
n=28: 1014720
n=29: 1213096
n=30: 1441512
n=31: 1703352
n=32: 2002196
n=33: 2341848
n=34: 2726400
n=35: 3160272
n=36: 3648180
n=37: 4195164
n=38: 4806496
n=39: 5487792
n=40: 6244992
JavaScript (ES7), \$A(30)\$ in ~50 seconds¹
¹ When run locally on my laptop.
A very simple algorithm.
/**
 * Counts the equilateral triangles whose vertices are points of a tetrahedral
 * lattice with n layers. Layer z (0 <= z < n) is the triangular grid of points
 * (x, y) with 0 <= x <= y <= z.
 *
 * Geometrically, point (x, y, z) sits at
 *   X = 2x - y,  Y = (y - 2z/3) * sqrt(3),  Z = z * sqrt(8/3).
 * Expanding a squared distance in those coordinates gives
 *   3 * |P - Q|^2 = 3 dU^2 + dV^2 + 8 dz^2,  with U = 2x - y and V = 3y - 2z.
 * That is an exact integer, so side lengths are compared with === instead of a
 * floating-point tolerance.
 *
 * @param {number} n number of layers
 * @returns {number} number of unordered vertex triples forming an equilateral triangle
 */
function count(n) {
  // Integer coordinates (U, V, z) of every lattice point, in layer order.
  const us = [], vs = [], zs = [];
  for (let z = 0; z < n; z++)
    for (let y = 0; y <= z; y++)
      for (let x = 0; x <= y; x++) {
        us.push(2 * x - y);
        vs.push(3 * y - 2 * z);
        zs.push(z);
      }

  // Three times the squared Euclidean distance between points i and j.
  const dist3 = (i, j) => {
    const du = us[i] - us[j], dv = vs[i] - vs[j], dz = zs[i] - zs[j];
    return 3 * du * du + dv * dv + 8 * dz * dz;
  };

  const m = us.length;
  let cnt = 0;
  for (let i = 0; i < m; i++)
    for (let j = i + 1; j < m; j++) {
      // The (scaled) side length the third vertex must match; invariant over k.
      const side = dist3(i, j);
      for (let k = j + 1; k < m; k++)
        if (dist3(i, k) === side && dist3(j, k) === side) cnt++;
    }
  return cnt;
}
Try it online!