#!/usr/bin/env python3
"""
N2 — THE WITHDRAWAL, MEASURED
==============================
Monogamy Gravity, Folio III numerical experiment.

CLAIM UNDER TEST (Postulate 3, "The Withdrawal"):
    Matter-matter entanglement must be debited from the vacuum account.
    A system cannot hold new entanglement with a distant partner without
    shedding entanglement/correlation with the vacuum around it.

KILL-GATE (written before the run, per house rules):
    If, as A:B entanglement rises, A's correlation with the intervening
    vacuum region M does NOT fall, the withdrawal picture of Postulate 3
    is falsified in its simplest arena and the theory is dead.

PROTOCOL ("the import"):
    1. Ground state of a harmonic chain (N sites, PBC, IR mass m):
       the vacuum. Gaussian state, covariances X = K^{-1/2}/2, P = K^{1/2}/2.
    2. Prepare an EXTERNAL two-mode-squeezed pair of ancillas (squeezing s):
       imported entanglement, initially uncorrelated with the chain.
    3. Couple ancilla alpha to chain site a, ancilla beta to chain site b,
       through beam splitters of angle theta (transmissivity tau = sin^2 theta,
       i.e. the fraction of each ancilla mode swapped into its chain site).
       theta = 0: nothing happens. theta = pi/2: full swap — the imported pair
       replaces the local vacuum modes entirely.
    4. Trace out the ancillas. Measure, as a function of theta:
         - I(A:B), E_N(A:B): the imported matter-matter account
         - I(A:M):           A's account with the intervening vacuum M
         - S(A), S(M), dE:   sanity checks and the energy bill
    Beam splitters are passive (no squeezing of the chain); all new
    entanglement is imported, none is minted in place.

CONTROL ARM (s = 0):
    Vacuum ancillas import NOTHING. Any fall of I(A:M) in the control is
    mode-replacement noise, not payment for entanglement. The
    withdrawal-specific signal is the difference between arms.

METHOD: Peschel (2003) for Gaussian states. With <qp> = 0 throughout
(the protocol mixes q with q and p with p only), the symplectic spectrum
of a region R is nu_k = sqrt(eig(X_R P_R)) and
S(R) = sum (nu+1/2)ln(nu+1/2) - (nu-1/2)ln(nu-1/2).
Log-negativity from the partial transpose (flip sign of P rows/cols in B).

Note: on macOS/Accelerate, numpy emits spurious RuntimeWarnings from
matmul; verified harmless (min eig K = m^2 exactly, no NaN, bounded X).

Plain numpy; no Sage required.
"""

import numpy as np
import json
import warnings
# Silence only the spurious matmul warnings described in the module docstring
# (macOS/Accelerate). Every other RuntimeWarning (overflow or invalid values in
# log/sqrt/exp, etc.) stays visible, because it would signal a real numerical
# problem rather than the known-harmless BLAS artefact.
warnings.filterwarnings("ignore", message=r".*encountered in matmul",
                        category=RuntimeWarning)


# ---------------------------------------------------------------- machinery
def chain_covariances(N, m):
    """Ground-state covariances X = <qq>, P = <pp> of the PBC chain.

    The chain Hamiltonian is H = (p.p + q.K.q) / 2 with
    K = (m^2 + 2) I - S - S^T, where S is the cyclic shift. Every site thus
    carries a mass term m^2 and two nearest-neighbour springs (periodic
    boundary), and the smallest eigenvalue of K is m^2.

    Parameters
    ----------
    N : int
        Number of sites (>= 1).
    m : float
        IR mass.

    Returns
    -------
    K, X, P : ndarray, shape (N, N)
        Coupling matrix and ground-state covariances
        X = K^{-1/2} / 2 and P = K^{+1/2} / 2.

    Raises
    ------
    ValueError
        If N < 1.
    """
    if N < 1:
        raise ValueError(f"chain needs at least one site, got N={N}")
    eye = np.eye(N)
    # Build the ring from cyclic shifts instead of setting K[i, i+1] in a loop.
    # This keeps the spring count right for N <= 2: the loop version overwrote
    # the diagonal (N = 1) or merged the two PBC bonds into one (N = 2).
    # For N >= 3 the result is identical entry by entry.
    K = (m * m + 2.0) * eye - np.roll(eye, 1, axis=1) - np.roll(eye, -1, axis=1)
    w2, U = np.linalg.eigh(K)
    # Clip before the square root so the near-zero mode (m -> 0) keeps
    # 0.5 / w finite.
    w = np.sqrt(np.clip(w2, 1e-30, None))
    X = (U * (0.5 / w)) @ U.T          # K^{-1/2} / 2
    P = (U * (0.5 * w)) @ U.T          # K^{+1/2} / 2
    return K, X, P


def entropy(X, P, region):
    """Return the von Neumann entropy of `region` for a Gaussian state (Peschel).

    Assumes <qp> = 0, so the symplectic eigenvalues of the region are
    nu_k = sqrt(eig(X_R P_R)) and
    S = sum_k (nu_k + 1/2) ln(nu_k + 1/2) - (nu_k - 1/2) ln(nu_k - 1/2).
    Eigenvalues are clipped to the physical bound nu_k >= 1/2, which absorbs
    round-off just below it. Modes with nu_k - 1/2 <= 1e-12 contribute nothing
    to the second sum.
    """
    block = np.ix_(region, region)
    sq_modes = np.linalg.eigvals(X[block] @ P[block]).real
    modes = np.sqrt(np.clip(sq_modes, 0.25, None))
    plus = modes + 0.5
    minus = modes - 0.5
    # Clip the log argument so pure modes (minus == 0) never evaluate log(0).
    safe_minus = np.clip(minus, 1e-300, None)
    minus_terms = np.where(minus > 1e-12, minus * np.log(safe_minus), 0.0)
    total = np.sum(plus * np.log(plus)) - np.sum(minus_terms)
    return float(total)


def mutual_info(X, P, R1, R2):
    """Mutual information I(R1:R2) = S(R1) + S(R2) - S(R1 u R2).

    R1 and R2 are collections of site indices (lists, tuples, ranges or
    integer arrays). They are assumed disjoint; overlap is not checked.
    Both are converted to lists before joining, because ``+`` raises
    TypeError on ranges and adds element-wise on numpy arrays instead of
    concatenating.
    """
    r1, r2 = list(R1), list(R2)
    return entropy(X, P, r1) + entropy(X, P, r2) - entropy(X, P, r1 + r2)


def log_negativity(X, P, R1, R2):
    """E_N between R1 and R2: partial transpose = flip p-sign on R2.

    Assumes <qp> = 0, so the partially transposed state has momentum block
    D P D, where D = +1 on R1 sites and -1 on R2 sites. Its symplectic
    eigenvalues are sqrt(eig(X_R (D P D)_R)), and
    E_N = sum over eigenvalues nu < 1/2 of -ln(2 nu).

    R1 and R2 may be any collections of site indices (lists, tuples, ranges
    or integer arrays). They are converted to lists so that joining them
    concatenates rather than adding element-wise, and so that len(R1)
    positions the sign flip correctly.
    """
    r1, r2 = list(R1), list(R2)
    region = r1 + r2
    idx = np.ix_(region, region)
    D = np.ones(len(region))
    D[len(r1):] = -1.0
    Pt = (P[idx] * D).T * D            # D P D
    nu2 = np.linalg.eigvals(X[idx] @ Pt).real
    # Floor tiny/negative round-off so the log below stays finite.
    nu = np.sqrt(np.clip(nu2, 1e-30, None))
    # Only partially-transposed eigenvalues below 1/2 signal entanglement.
    return float(np.sum(np.where(nu < 0.5, -np.log(2.0 * nu), 0.0)))


def energy(K, X, P):
    """Return <H> for H = (p.p + q.K.q) / 2 from the second moments.

    X = <qq> and P = <pp>, so <H> = tr(P) / 2 + tr(K X) / 2.
    """
    kinetic = 0.5 * np.trace(P)
    potential = 0.5 * np.trace(K @ X)
    return kinetic + potential


def imported_state(theta, K, Xv, Pv, N, a_site, b_site, s_anc):
    """Chain covariances after importing the ancilla pair at angle theta.

    The chain state (Xv, Pv) is joined with a two-mode-squeezed ancilla pair
    (squeezing s_anc) placed at indices N (alpha) and N + 1 (beta). Ancilla
    alpha is mixed with chain site a_site and beta with b_site by passive beam
    splitters [[cos, sin], [-sin, cos]] applied identically to q and p. The
    ancillas are then traced out.

    Parameters
    ----------
    theta : float
        Beam-splitter angle; 0 leaves the chain untouched, pi/2 swaps the
        ancilla modes into the two sites.
    K : ndarray
        Chain coupling matrix. Not used here; accepted for call-site symmetry.
    Xv, Pv : ndarray, shape (N, N)
        Chain <qq> and <pp> covariances before the import.
    N : int
        Number of chain sites.
    a_site, b_site : int
        Distinct chain sites in [0, N).
    s_anc : float
        Two-mode squeezing of the ancilla pair (0 gives vacuum ancillas).

    Returns
    -------
    X, P : ndarray, shape (N, N)
        Chain covariances after the import.

    Raises
    ------
    ValueError
        If a site is outside [0, N) or if a_site == b_site.
    """
    # The sites are raw indices into an (N + 2)-mode matrix. A negative or
    # too-large value would silently hit an ancilla row instead of a chain
    # site. Equal sites would let the second beam splitter overwrite the
    # first, leaving R non-orthogonal and the state unphysical.
    for name, site in (("a_site", a_site), ("b_site", b_site)):
        if not 0 <= site < N:
            raise ValueError(f"{name}={site} is not a chain site in [0, {N})")
    if a_site == b_site:
        raise ValueError(f"a_site and b_site must differ, both are {a_site}")

    Nt = N + 2                          # chain + ancillas alpha (N), beta (N+1)
    X = np.zeros((Nt, Nt)); P = np.zeros((Nt, Nt))
    X[:N, :N] = Xv; P[:N, :N] = Pv
    ch, sh = np.cosh(2 * s_anc) / 2, np.sinh(2 * s_anc) / 2
    X[N:, N:] = [[ch, sh], [sh, ch]]    # TMS pair: q correlated,
    P[N:, N:] = [[ch, -sh], [-sh, ch]]  #           p anti-correlated
    R = np.eye(Nt)                      # passive beam splitters, same on q & p
    c, s = np.cos(theta), np.sin(theta)
    for site, anc in ((a_site, N), (b_site, N + 1)):
        R[site, site] = c; R[site, anc] = s
        R[anc, site] = -s; R[anc, anc] = c
    X = R @ X @ R.T; P = R @ P @ R.T
    return X[:N, :N], P[:N, :N]         # trace out ancillas


# ---------------------------------------------------------------- experiment
def run_experiment(label, N=240, m=1e-3, a_site=60, b_site=180,
                   M=None, s_anc=1.5, n_theta=13, verbose=True):
    """Sweep the beam-splitter angle and record the A:B and A:M accounts.

    Parameters
    ----------
    label : str
        Name printed in the table header and stored in the result.
    N, m : int, float
        Chain size and IR mass passed to ``chain_covariances``.
    a_site, b_site : int
        Chain sites A and B that receive ancillas alpha and beta.
    M : list of int, optional
        Intervening vacuum region; defaults to sites 110..130.
    s_anc : float
        Two-mode squeezing of the imported pair (0 = control arm).
    n_theta : int
        Number of angles on [0, pi/2], endpoints included.
    verbose : bool
        Print the per-angle table and a one-line summary.

    Returns
    -------
    dict
        ``label``, ``params``, per-angle ``rows`` and a ``verdict`` with:
        ``rising`` (I(A:B) ends measurably above its start), ``falling``
        (I(A:M) never increases AND ends measurably below its start),
        ``drained_fraction`` (1 - I_final / I_initial for A:M) and
        ``SM_drift`` (largest change of S(M) over the sweep).
    """
    M = M if M is not None else list(range(110, 131))
    K, Xv, Pv = chain_covariances(N, m)
    A, B = [a_site], [b_site]
    E0 = energy(K, Xv, Pv)
    thetas = np.linspace(0.0, np.pi / 2, n_theta)

    rows = []
    if verbose:
        print(f"\n=== {label}:  N={N}, m={m}, |a-b|={b_site - a_site}, "
              f"M={M[0]}..{M[-1]}, ancilla squeezing s={s_anc}")
        print(f"{'theta/pi':>9} {'tau':>6} {'I(A:B)':>9} {'E_N(A:B)':>9} "
              f"{'I(A:M)':>9} {'S(A)':>8} {'S(M)':>8} {'dE':>8}")
    for th in thetas:
        X, P = imported_state(th, K, Xv, Pv, N, a_site, b_site, s_anc)
        r = dict(theta=float(th), tau=float(np.sin(th) ** 2),
                 IAB=mutual_info(X, P, A, B), EN_AB=log_negativity(X, P, A, B),
                 IAM=mutual_info(X, P, A, M), IBM=mutual_info(X, P, B, M),
                 SA=entropy(X, P, A), SM=entropy(X, P, M),
                 dE=float(energy(K, X, P) - E0))
        rows.append(r)
        if verbose:
            print(f"{th / np.pi:9.4f} {r['tau']:6.3f} {r['IAB']:9.5f} "
                  f"{r['EN_AB']:9.5f} {r['IAM']:9.5f} "
                  f"{r['SA']:8.5f} {r['SM']:8.5f} {r['dE']:8.4f}")

    I_AB = np.array([r['IAB'] for r in rows])
    I_AM = np.array([r['IAM'] for r in rows])
    rising = bool(I_AB[-1] > I_AB[0] + 1e-9)
    # The kill-gate asks whether I(A:M) *falls*. Monotone non-increase alone
    # is not enough: a flat (e.g. identically zero) I(A:M) would pass
    # vacuously. So also require a net drop, using the same tolerance as
    # `rising`.
    monotone = bool(np.all(np.diff(I_AM) < 1e-12))
    net_drop = bool(I_AM[-1] < I_AM[0] - 1e-9)
    falling = monotone and net_drop
    drained = float(1.0 - I_AM[-1] / I_AM[0]) if I_AM[0] > 0 else float('nan')
    sm_drift = max(abs(r['SM'] - rows[0]['SM']) for r in rows)
    verdict = dict(rising=rising, falling=falling, drained_fraction=drained,
                   SM_drift=float(sm_drift))
    if verbose:
        if falling:
            am_word = 'falls'
        elif monotone:
            am_word = 'FLAT'
        else:
            am_word = 'NOT MONOTONE'
        print(f"--- I(A:B) {'rises' if rising else 'FLAT'} "
              f"({I_AB[0]:.5f} -> {I_AB[-1]:.5f}) · "
              f"I(A:M) {am_word} "
              f"({I_AM[0]:.5f} -> {I_AM[-1]:.5f}, "
              f"{100 * drained:.1f}% drained) · S(M) drift {sm_drift:.1e}")
    return dict(label=label,
                params=dict(N=N, m=m, a=a_site, b=b_site,
                            M=[M[0], M[-1]], s_anc=s_anc),
                rows=rows, verdict=verdict)


if __name__ == '__main__':
    import sys
    from pathlib import Path

    print("N2 — the withdrawal, measured.")
    main    = run_experiment("MAIN ARM   (near-critical, import s=1.5)",
                             m=1e-3, s_anc=1.5)
    control = run_experiment("CONTROL ARM(near-critical, vacuum ancillas s=0)",
                             m=1e-3, s_anc=0.0)
    massive = run_experiment("MASSIVE ARM(m=0.1, import s=1.5)",
                             m=0.1, s_anc=1.5)

    # The account statement: where A's debit is drained from (main arm, full
    # swap). Read the parameters back from the main arm's record so this
    # profile cannot drift out of sync with the run it describes.
    mp = main['params']
    N, a_site, b_site, s_anc = mp['N'], mp['a'], mp['b'], mp['s_anc']
    K, Xv, Pv = chain_covariances(N, mp['m'])
    X1, P1 = imported_state(np.pi / 2, K, Xv, Pv, N, a_site, b_site, s_anc)
    profile = [dict(site=j,
                    I_vac=mutual_info(Xv, Pv, [a_site], [j]),
                    I_after=mutual_info(X1, P1, [a_site], [j]))
               for j in range(0, N, 4) if j != a_site]

    # ------------------------------------------------------------- verdict
    print("\n" + "=" * 64)
    print("VERDICT")
    m_v, c_v = main['verdict'], control['verdict']
    main_IAB = main['rows'][-1]['IAB'] - main['rows'][0]['IAB']
    ctrl_IAB = control['rows'][-1]['IAB'] - control['rows'][0]['IAB']
    # Only claim a withdrawal or robustness when the measured flags show it.
    main_tag = ("withdrawal present" if m_v['rising'] and m_v['falling']
                else "withdrawal NOT established")
    massive_falling = massive['verdict']['falling']
    robust_tag = ("robust away from criticality" if massive_falling
                  else "NOT robust away from criticality")
    print(f"main arm:    dI(A:B) = {main_IAB:+.4f}, I(A:M) drained "
          f"{100 * m_v['drained_fraction']:.1f}%  -> {main_tag}")
    print(f"control arm: dI(A:B) = {ctrl_IAB:+.4f}, I(A:M) drained "
          f"{100 * c_v['drained_fraction']:.1f}%  -> replacement noise floor")
    print(f"massive arm: falling = {massive_falling} "
          f"({robust_tag})")
    gate_ok = m_v['rising'] and m_v['falling'] and massive_falling
    print("\nKILL-GATE: " + ("PASSED — hosting A:B entanglement always drains "
                             "A's vacuum account." if gate_ok else
                             "TRIPPED — Postulate 3 falsified in this arena."))
    print("HONEST NOTE: the control arm drains too — mode replacement alone "
          "severs A's vacuum accounts. The withdrawal-specific content is "
          "(i) the drain occurs identically whether or not entanglement is "
          "imported (monogamy doesn't care what evicts the vacuum), and "
          "(ii) only the import arm converts the vacated account into A:B "
          "entanglement. The vacancy is necessary either way; that is the "
          "monogamy constraint.")

    # Write to the first CLI argument if given, else n2_results.json in the
    # current directory. A hard-coded absolute path would crash on any other
    # machine after all the compute above had already run.
    out_path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path('n2_results.json')
    with out_path.open('w', encoding='utf-8') as f:
        json.dump(dict(main=main, control=control, massive=massive,
                       profile=profile), f, indent=1)
    print(f"\nresults -> {out_path}")
