#!/usr/bin/env python3
"""
N10 — THE BROADCAST: quantum Darwinism of geometry in the vacuum
==================================================================
OUTLIER THESIS: classical, objective spacetime = the massless vacuum
redundantly broadcasting the positions of masses. A gapped vacuum keeps
geometry private (quantum); a massless one publishes it.

KILL-GATE (written before the run):
  If the massless vacuum does NOT show broadcast structure — i.e. if
  the number of disjoint fragments that can each independently
  distinguish two defect positions is comparable between massless and
  gapped vacua — the thesis is dead.

ARM 1 (value redundancy): mutual information I(A : F_j) between a single
  probe site A and each disjoint 8-site fragment F_j along the chain.
  Broadcast = many fragments carry comparable MI.
ARM 2 (geometry redundancy — the real one): pin a defect at site a or
  a+2; per disjoint fragment compute the quantum relative entropy
  D_j = S(rho_F(a) || rho_F(b)) — the discrimination power of that
  fragment alone about WHERE the mass is. Redundancy R(delta) = number
  of disjoint fragments with D_j > delta.

Gaussian machinery: nu from eig(XP); EH via symplectic normal form
  S = X^(1/2) O D_nu^(-1/2):  h_qq = X^(-1/2) O (nu*eps) O^T X^(-1/2),
  h_pp = X^(1/2) O (eps/nu) O^T X^(1/2),  eps = ln((nu+1/2)/(nu-1/2)),
  lnZ = -Sum ln(2 sinh(eps/2));
  S(r1||r2) = -S(r1) + 1/2[Tr(h2_qq X1)+Tr(h2_pp P1)] + lnZ2.
Self-test S(rho||rho) = 0 is run as the numerical gate.
"""
import numpy as np, json, warnings
# numpy RuntimeWarnings are silenced for the whole script.
warnings.filterwarnings("ignore", category=RuntimeWarning)

N: int = 240    # number of sites on the periodic chain
FRAG: int = 8   # sites per disjoint fragment

def covs(m, pin=None, M2=1.0):
    """Return the vacuum covariance pair (X, P) of a periodic harmonic chain.

    The coupling matrix on the ring of N sites is
    K = (m^2 + 2) on the diagonal and -1 between nearest neighbours
    (with wrap-around). If ``pin`` is not None, ``M2`` is added to
    K[pin, pin] (a local mass defect).

    With K = U diag(w^2) U^T (w^2 clipped from below at 1e-30), returns
    X = U diag(1/(2w)) U^T and P = U diag(w/2) U^T.
    """
    sites = np.arange(N)
    right = (sites + 1) % N          # periodic nearest neighbour
    K = np.zeros((N, N))
    K[sites, sites] = m*m + 2.0
    K[sites, right] = -1.0
    K[right, sites] = -1.0
    if pin is not None:
        K[pin, pin] += M2
    omega2, modes = np.linalg.eigh(K)
    omega = np.sqrt(np.clip(omega2, 1e-30, None))
    X = (modes * (0.5/omega)) @ modes.T
    P = (modes * (0.5*omega)) @ modes.T
    return X, P

def sqrtm_sym(A):
    """Square root of a symmetric matrix via its eigendecomposition.

    Eigenvalues are clipped from below at 1e-300 before the square root, so
    round-off negative eigenvalues map to a tiny positive value instead of NaN.
    """
    evals, evecs = np.linalg.eigh(A)
    root = np.sqrt(np.clip(evals, 1e-300, None))
    return (evecs * root) @ evecs.T

def block(M, R):
    """Return the principal sub-matrix of ``M`` on the index set ``R``.

    ``R`` selects both rows and columns, so the result is len(R) x len(R).
    """
    grid = np.ix_(R, R)
    return M[grid]

def normal_form(Xs, Ps):
    """Symplectic normal-form data of a state given its X and P blocks.

    Returns ``(Xh, O, nu)``:
      Xh -- Xs^(1/2) (via ``sqrtm_sym``),
      O  -- orthonormal eigenvectors of Xh @ Ps @ Xh,
      nu -- square roots of the corresponding eigenvalues (symplectic
            eigenvalues). Eigenvalues are floored at 0.25 + 1e-13, which keeps
            nu - 1/2 strictly positive.
    """
    root_x = sqrtm_sym(Xs)
    nu_sq, O = np.linalg.eigh(root_x @ Ps @ root_x)
    floored = np.clip(nu_sq, 0.25 + 1e-13, None)
    return root_x, O, np.sqrt(floored)

def entropy_from_nu(nu):
    """Von Neumann entropy from symplectic eigenvalues ``nu`` (NumPy array).

    S = sum (nu + 1/2) ln(nu + 1/2) - sum (nu - 1/2) ln(nu - 1/2),
    where modes with nu - 1/2 <= 1e-14 contribute zero to the second sum
    (the 0 ln 0 -> 0 limit). Returns a Python float.
    """
    plus = nu + .5
    minus = nu - .5
    first = np.sum(plus*np.log(plus))
    keep = minus > 1e-14
    second = np.sum(np.where(keep, minus*np.log(np.clip(minus, 1e-300, None)), 0))
    return float(first - second)

def eh_of(Xs, Ps):
    """Entanglement-Hamiltonian data of a state given its X and P blocks.

    With eps = ln((nu + 1/2) / (nu - 1/2)) per normal mode, returns
      h_qq = Xh^-1 O diag(nu*eps) O^T Xh^-1,
      h_pp = Xh   O diag(eps/nu) O^T Xh,
      lnZ  = -sum ln(2 sinh(eps/2))   (Python float),
    where (Xh, O, nu) come from ``normal_form``.
    """
    root_x, O, nu = normal_form(Xs, Ps)
    # The 1e-300 only guards an exact zero denominator.
    eps = np.log((nu + .5)/(nu - .5 + 1e-300))
    root_x_inv = np.linalg.inv(root_x)
    mode_q = O*(nu*eps)
    mode_p = O*(eps/nu)
    h_qq = root_x_inv @ mode_q @ O.T @ root_x_inv
    h_pp = root_x @ mode_p @ O.T @ root_x
    half_eps = np.clip(eps, 1e-300, None)/2
    lnZ = float(-np.sum(np.log(2*np.sinh(half_eps))))
    return h_qq, h_pp, lnZ

def rel_entropy(X1, P1, X2, P2):
    """Relative entropy S(rho1 || rho2) from the states' X and P blocks.

    Evaluates -S(rho1) + 1/2 [Tr(h2_qq X1) + Tr(h2_pp P1)] + lnZ2, with the
    entanglement-Hamiltonian data (h2_qq, h2_pp, lnZ2) of rho2 from ``eh_of``.
    Only the X and P blocks enter; no q-p cross block or first moments are
    taken as input. Returns a Python float.
    """
    h_qq, h_pp, lnZ2 = eh_of(X2, P2)
    S1 = entropy_from_nu(normal_form(X1, P1)[2])
    energy = 0.5*(np.trace(h_qq @ X1) + np.trace(h_pp @ P1))
    return float(-S1 + energy + lnZ2)

def MI(X, P, A, B):
    """Mutual information I(A : B) = S(A) + S(B) - S(AB).

    X, P : full position / momentum covariance matrices of the state.
    A, B : collections of site indices (lists, tuples, ranges or integer
           arrays). They are meant to be disjoint: a site listed in both
           appears twice in the joint block, whose covariance is then
           singular and the result is not meaningful.
    Returns a Python float.
    """
    def S(R):
        _, _, nu = normal_form(block(X, R), block(P, R))
        return entropy_from_nu(nu)
    # Concatenate explicitly as lists: for NumPy index arrays ``A + B`` would
    # add elementwise (silently selecting the wrong sites), and for ranges it
    # raises TypeError.
    AB = list(A) + list(B)
    return S(A) + S(B) - S(AB)

if __name__ == '__main__':
    # Numerical gate: S(rho||rho) must vanish on a fragment.
    Xv, Pv = covs(1e-3)
    F0 = list(range(40, 48))
    gate = rel_entropy(block(Xv, F0), block(Pv, F0), block(Xv, F0), block(Pv, F0))
    print(f"numerical gate S(rho||rho) = {gate:.2e}  (must be ~0)")
    # Explicit check instead of `assert` (stripped under `python -O`); the
    # `not ... <` form also fails when gate is NaN.
    if not abs(gate) < 1e-8:
        raise SystemExit(f"numerical gate failed: S(rho||rho) = {gate:.2e}")

    out = {}
    for label, m in (("MASSLESS (m=1e-3)", 1e-3), ("GAPPED (m=0.3)", 0.3)):
        Xv, Pv = covs(m)
        a, b = N//2 - 1, N//2 + 1      # the two defect positions, b = a + 2
        Xa, Pa = covs(m, pin=a)
        Xb, Pb = covs(m, pin=b)

        # Disjoint FRAG-site fragments tiling the chain, dropping any fragment
        # with a site within 2 of either defect position.
        tiles = [list(range(i, i + FRAG)) for i in range(0, N - FRAG + 1, FRAG)]
        frags = [F for F in tiles
                 if not any(abs(s - a) <= 2 or abs(s - b) <= 2 for s in F)]

        A0 = [N//2 - 40]   # arm-1 probe site, away from defect zone
        # A tiling fragment can contain the probe site itself (site 80 lies in
        # [80, 88) for N=240, FRAG=8). A0 + F would then list that site twice,
        # making the joint covariance singular and the MI meaningless, so such
        # a fragment is recorded as NaN instead of a spurious number.
        arm1 = [MI(Xv, Pv, A0, F) if set(A0).isdisjoint(F) else float('nan')
                for F in frags]

        arm2 = [rel_entropy(block(Xa, F), block(Pa, F),
                            block(Xb, F), block(Pb, F)) for F in frags]
        # Ring distance between site F[0] + FRAG//2 and the chain centre N//2.
        dists = [min(abs(F[0] + FRAG//2 - N//2), N - abs(F[0] + FRAG//2 - N//2))
                 for F in frags]

        for delta in (1e-6, 1e-8):
            R = sum(1 for d in arm2 if d > delta)
            print(f"{label}: R(delta={delta:g}) = {R}/{len(frags)} fragments "
                  f"can tell where the mass is")
        srt = np.argsort(dists)
        print(f"{label}: D_j vs distance: " +
              "  ".join(f"d={dists[k]}:{arm2[k]:.2e}" for k in srt[::4]))
        print(f"{label}: arm-1 MI per fragment (sorted by d): " +
              "  ".join(f"{arm1[k]:.4f}" for k in srt[::4]))
        out[label] = dict(dists=dists, arm1=arm1, arm2=arm2,
                          frags=len(frags))
    # NOTE(review): machine-specific absolute output path. json.dump writes
    # NaN entries (see arm1) as the non-standard token NaN, which Python's json
    # module reads back.
    with open('/Users/antoine/agi/ledger/n10_results.json', 'w') as fh:
        json.dump(out, fh, indent=1)
    print("-> n10_results.json")
