#!/usr/bin/env python3
"""
N4 (3+1D arm) — Newton from monogamy on the 18^3 lattice.
Two weak mass pins; all quantities are connected, i.e. Q_ab - Q_a - Q_b + Q_0
(both pins, pin at a only, pin at b only, no pins). Perturbative expectation:
both E_int(r) and dI_conn(r) scale as r^-4 (mass-insertion / van der Waals
sector). The Ledger's claim is the EQUALITY of the two power laws, not the
exponent itself.
"""
import numpy as np, json, warnings, time
# Silence numpy floating-point RuntimeWarnings for the whole run.
# NOTE(review): this also hides genuine overflow/invalid-value warnings;
# consider np.errstate(...) around the specific calls instead.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Lattice side, bare mass, and pin strength (M2 is added to K's diagonal
# at each pinned site).
L, m, M2 = 18, 0.02, 0.5
N = L ** 3                  # total number of lattice sites
seps = (3, 4, 5, 6, 7, 8)   # pin separations r, in lattice units along x


def idx(x, y, z):
    """Return the flat index of site (x, y, z) on the periodic L^3 lattice.

    Each coordinate is wrapped modulo L, so out-of-range coordinates
    (e.g. -1 or L) resolve to the opposite face of the torus. Ordering is
    row-major: x varies slowest, z fastest.
    """
    return (x % L) * L * L + (y % L) * L + (z % L)

def build_K(size=None, mass=None):
    """Return the dense coupling matrix K of a massive scalar on a periodic cube.

    K = mass^2 * I minus the nearest-neighbour lattice Laplacian with
    periodic wrap: every diagonal entry is ``mass**2 + 6`` and every
    nearest-neighbour bond (along x, y or z, wrapping across the boundary)
    contributes ``-1`` to the symmetric off-diagonal pair. Sites are
    flattened row-major, i.e. (x, y, z) -> x*size**2 + y*size + z, the
    same ordering as ``idx``.

    Parameters
    ----------
    size : int, optional
        Lattice side length; defaults to the module constant ``L``.
    mass : float, optional
        Bare mass; defaults to the module constant ``m``.

    Returns
    -------
    numpy.ndarray
        Symmetric float64 array of shape ``(size**3, size**3)``.
    """
    size = L if size is None else size
    mass = m if mass is None else mass
    n_sites = size ** 3
    K = np.zeros((n_sites, n_sites))
    np.fill_diagonal(K, mass * mass + 6.0)
    sites = np.arange(n_sites).reshape(size, size, size)
    src = sites.ravel()
    for axis in range(3):
        # Forward neighbour along `axis`; np.roll supplies the periodic wrap.
        nbr = np.roll(sites, -1, axis=axis).ravel()
        # ufunc.at accumulates repeated (i, j) pairs, so tiny lattices
        # (size <= 2, where forward and backward neighbours coincide) get
        # the same multi-bond entries as a per-site loop would.
        np.subtract.at(K, (src, nbr), 1.0)
        np.subtract.at(K, (nbr, src), 1.0)
    return K

def solve(K):
    """Diagonalise K and return the Gaussian ground-state data.

    With frequencies w = sqrt(eig(K)) (eigenvalues clipped below at 1e-30),
    returns ``(E, X, P)`` where E = sum(w) / 2 is the zero-point energy as a
    Python float, X = K^{-1/2} / 2 and P = K^{1/2} / 2 (both built from the
    eigendecomposition of K).
    """
    evals, vecs = np.linalg.eigh(K)
    freqs = np.sqrt(np.clip(evals, 1e-30, None))
    energy = 0.5 * float(np.sum(freqs))
    # Scale eigenvector columns, then contract back: U diag(f(w)) U^T.
    phi_corr = (vecs * (0.5 / freqs)) @ vecs.T
    pi_corr = (vecs * (0.5 * freqs)) @ vecs.T
    return energy, phi_corr, pi_corr

def S(X, P, R):
    """Return the entropy of the Gaussian state (X, P) restricted to sites R.

    Symplectic eigenvalues nu = sqrt(eig(X_R P_R)) are clipped below at 1/2,
    and the result is sum[(nu + 1/2) ln(nu + 1/2) - (nu - 1/2) ln(nu - 1/2)],
    where the second term is dropped for nu - 1/2 <= 1e-12.
    """
    sub = np.ix_(R, R)
    lam = np.linalg.eigvals(X[sub] @ P[sub]).real
    nu = np.sqrt(np.clip(lam, 0.25, None))
    plus = nu + .5
    minus = nu - .5
    pos_term = np.sum(plus * np.log(plus))
    # Clip only to keep log() finite; those entries are masked out anyway.
    safe_minus = np.clip(minus, 1e-300, None)
    neg_term = np.sum(np.where(minus > 1e-12, minus * np.log(safe_minus), 0))
    return float(pos_term - neg_term)

def MI(X, P, A, B):
    """Return the Gaussian mutual information I(A:B) = S(A) + S(B) - S(A u B).

    A and B are sequences of site indices (lists, tuples or integer arrays)
    and must be disjoint. They are converted to lists before being joined,
    because ``+`` on NumPy arrays would add indices elementwise rather than
    concatenate them.

    Raises
    ------
    ValueError
        If A and B share a site. Duplicated indices would make the joint
        sub-blocks of X and P singular and silently corrupt S(A u B).
    """
    A, B = list(A), list(B)
    if set(A) & set(B):
        raise ValueError("MI requires disjoint regions A and B")
    return S(X, P, A) + S(X, P, B) - S(X, P, A + B)

t0 = time.time()
print(f"N4 3D: L={L} ({N} sites), m={m}, M^2={M2}", flush=True)

# Vacuum: no pins.
K0 = build_K()
E0, Xv, Pv = solve(K0)
print(f"vacuum solved  t={time.time()-t0:.0f}s", flush=True)

# Single pin of strength M2 at the lattice centre a = (c, c, c).
c = L // 2
ia = idx(c, c, c)
K1 = K0.copy(); K1[ia, ia] += M2
E1, Xa, Pa = solve(K1)
del K1  # only (E1, Xa, Pa) are used from here on; free the N x N matrix
print(f"single pin solved  t={time.time()-t0:.0f}s", flush=True)

rows = []
for r in seps:
    # Second pin at b = a + r along x.
    ib = idx(c + r, c, c)
    K2 = K0.copy(); K2[ia, ia] += M2; K2[ib, ib] += M2
    E2, X2, P2 = solve(K2)
    # Connected energy E_ab - E_a - E_b + E_0; the torus is translation
    # invariant, so E_b == E_a and the two single-pin terms are 2 * E1.
    Eint = E2 - 2 * E1 + E0
    A, B = [ia], [ib]
    # The reflection x -> 2c + r - x (mod L) is a symmetry of the torus that
    # swaps sites a and b, mapping the pin-at-b state onto the pin-at-a
    # state with A and B exchanged. MI is symmetric in its two regions, so
    # MI_b(A:B) == MI_a(A:B) and only the pin-at-a value is computed.
    mi_a = MI(Xa, Pa, A, B)
    dI = MI(X2, P2, A, B) - 2 * mi_a + MI(Xv, Pv, A, B)
    # Release this separation's N x N matrices before the next solve so
    # they do not inflate peak memory during the following eigh.
    del K2, X2, P2
    rows.append(dict(r=r, Eint=Eint, dI=dI))
    print(f"r={r}  E_int={Eint:.4e}  dI_conn={dI:+.4e}  "
          f"ratio={Eint/dI if abs(dI) > 1e-300 else float('nan'):.3e}  "
          f"t={time.time()-t0:.0f}s", flush=True)

# Log-log least-squares slopes of |E_int| and |dI_conn| against r.
# NOTE(review): taking abs() hides sign changes across r; a slope is only
# meaningful if each quantity keeps one sign over the fitted range.
rr = np.array([q['r'] for q in rows], float)
eE = np.abs([q['Eint'] for q in rows])
eI = np.abs([q['dI'] for q in rows])
pE = np.polyfit(np.log(rr), np.log(np.clip(eE, 1e-300, None)), 1)[0]
pI = np.polyfit(np.log(rr), np.log(np.clip(eI, 1e-300, None)), 1)[0]
print(f"\npower exponents:  E_int ~ r^{pE:+.2f}   dI_conn ~ r^{pI:+.2f}")

# NOTE(review): user-specific absolute output path; adjust per machine.
out_path = '/Users/antoine/agi/ledger/n4_3d.json'
with open(out_path, 'w', encoding='utf-8') as fh:
    json.dump(dict(L=L, m=m, M2=M2, rows=rows, pE=float(pE), pI=float(pI)),
              fh, indent=1)
print(f"-> {out_path}")
