#!/usr/bin/env python3
"""
N6 — THE LANDAUER RATE (first open-system run)
================================================
Monogamy Gravity — the hardest gate: the fee's rate per nat, measured in
genuine dynamics rather than static spectra.

CLAIM UNDER TEST (Postulate 4, rate form):
    A transaction mediated through the vacuum across a causal span d is
    billed at that diamond's temperature: fee-per-nat ~ T_d ~ 1/d.
    Operationally: two probes entangling THROUGH the vacuum at
    separation d must accumulate mixedness (the fee, D = S(ab)) in
    proportion to the entanglement transacted (E), with a price that
    FALLS as the diamond grows:
        F(d) = D/E  ~  1/d   (larger rooms charge less per nat)

GATE (written before the run):
    If the fee per nat F(d) does not decrease with d — flat or rising —
    the diamond-pricing of the fee fails in its first dynamical arena.

PROTOCOL:
    Chain vacuum (N sites, PBC, mass m) + two probe oscillators
    (frequency W, in-band) attached by q-q coupling eps at sites
    separated by d. Everything quadratic: the full covariance evolves
    exactly, Gamma(t) = e^{At} Gamma(0) e^{A^T t}, A = [[0,1],[-K,0]].
    Initial state: chain vacuum x probe ground states (product) -- the
    coupled system then evolves; the chain mediates an effective a-b
    interaction (the transaction) while entangling with the probes
    (the fee).
    Measured vs time: E_N(a:b) (entanglement transacted),
    S(ab) (fee paid), MI(a:b). At the first entanglement peak t*(d):
    F(d) = S(ab)/E_N. Sweep d.

Exact Gaussian dynamics; plain numpy + scipy.linalg.expm.
"""

import numpy as np
import json
import warnings
from scipy.linalg import expm
# Suppress every RuntimeWarning for the rest of the process (this also hides
# numpy's floating-point warnings).
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ---- model parameters -------------------------------------------------------
N = 300            # number of chain sites (periodic boundary conditions)
m = 0.02           # chain mass; enters K as m*m on the diagonal
W2 = 1.0           # probe frequency^2 (in band)
eps = 0.30         # probe-chain q-q coupling

# ---- time grid --------------------------------------------------------------
T_MAX = 80.0       # last sampled time
NT = 81            # number of samples on [0, T_MAX]


def build(d, *, n_chain=None, mass=None, w2=None, coupling=None):
    """K matrix for chain + 2 probes attached at separation d.

    The quadratic Hamiltonian is H = p.p/2 + q.K.q/2 over n_chain + 2
    oscillators:
      * indices 0..n_chain-1: periodic nearest-neighbour chain with diagonal
        mass**2 + 2 and off-diagonal -1;
      * indices n_chain, n_chain+1: the two probes, diagonal w2, each coupled
        q-q with strength `coupling` to one chain site. The two sites sit
        around n_chain // 2 and are exactly d apart.

    Keyword arguments default to the module constants N, m, W2, eps.

    Returns
    -------
    (K, ia, ib): K is the symmetric (n_chain+2)x(n_chain+2) matrix; ia and ib
    are the probe indices.

    Raises
    ------
    ValueError if d < 0 or the two attachment sites do not fit on the chain.

    Warns
    -----
    UserWarning if K is not positive definite. For w2 > 0 this happens exactly
    when Kc - (coupling**2 / w2) (e_a e_a^T + e_b e_b^T) is not positive
    definite. On the uniform chain mode that quantity is roughly
    mass**2 - 2 coupling**2 / (w2 n_chain). The coupled Hamiltonian is then
    unbounded below, and the "vacuum" evolution contains exponentially
    growing modes.
    """
    n_chain = N if n_chain is None else n_chain
    mass = m if mass is None else mass
    w2 = W2 if w2 is None else w2
    coupling = eps if coupling is None else coupling

    n = n_chain + 2
    K = np.zeros((n, n))
    for i in range(n_chain):
        K[i, i] = mass * mass + 2.0
        K[i, (i + 1) % n_chain] = -1.0
        K[(i + 1) % n_chain, i] = -1.0
    a_site = n_chain // 2 - d // 2
    b_site = n_chain // 2 + (d + 1) // 2        # b_site - a_site == d
    # Negative indices would silently wrap onto other sites (or probes).
    if d < 0 or a_site < 0 or b_site >= n_chain:
        raise ValueError(f"separation d={d} does not fit on a chain of {n_chain} sites")
    ia, ib = n_chain, n_chain + 1                # probe indices
    K[ia, ia] = w2
    K[ib, ib] = w2
    K[ia, a_site] = K[a_site, ia] = coupling
    K[ib, b_site] = K[b_site, ib] = coupling

    # Stability: a non-positive eigenvalue of K means the dynamics blow up.
    lam_min = float(np.linalg.eigvalsh(K)[0])
    if lam_min <= 0.0:
        warnings.warn(
            f"build(d={d}): K is not positive definite (min eigenvalue "
            f"{lam_min:.3e}); the coupled Hamiltonian is unbounded below and "
            f"the evolution has exponentially growing modes. Reduce the "
            f"coupling or increase the mass / probe frequency.",
            UserWarning,
            stacklevel=2,
        )
    return K, ia, ib


def initial_covariance(n):
    """Build the t=0 covariance: chain ground state (x) two probe ground states.

    Ordering is (q_0..q_{n-1}, p_0..p_{n-1}). Indices 0..N-1 hold the chain;
    N and N+1 hold the probes. run() passes n = N + 2. Each normal mode of
    frequency w contributes <q^2> = 1/(2w) and <p^2> = w/2, and the state
    has no q-p correlations.
    """
    sites = np.arange(N)
    right = (sites + 1) % N
    chain_K = np.zeros((N, N))
    chain_K[sites, sites] = m * m + 2.0
    chain_K[sites, right] = -1.0
    chain_K[right, sites] = -1.0

    # Normal-mode decomposition of the bare chain; the clip keeps sqrt away
    # from non-positive round-off eigenvalues.
    freq_sq, modes = np.linalg.eigh(chain_K)
    freq = np.sqrt(np.clip(freq_sq, 1e-30, None))

    cov_q = np.zeros((n, n))
    cov_p = np.zeros((n, n))
    cov_q[:N, :N] = (modes * (0.5 / freq)) @ modes.T
    cov_p[:N, :N] = (modes * (0.5 * freq)) @ modes.T

    probe_freq = np.sqrt(W2)
    probes = [N, N + 1]
    cov_q[probes, probes] = 0.5 / probe_freq
    cov_p[probes, probes] = 0.5 * probe_freq

    zero = np.zeros((n, n))
    return np.block([[cov_q, zero], [zero, cov_p]])


def symp_eigs(G):
    """Symplectic eigenvalues of a (2k x 2k) covariance in (q...,p...) basis.

    The eigenvalues of Omega @ G are +-i*nu_j, so the sorted magnitudes of
    their imaginary parts list every nu_j twice: [nu_1, nu_1, nu_2, nu_2, ...].
    The k symplectic eigenvalues are recovered by averaging consecutive pairs.
    Taking the upper half of the sorted list would drop the smallest nu_j and
    double-count the largest.

    Returns the k values in ascending order, clipped below at 0.5 (the vacuum
    bound) to remove round-off.
    """
    k = G.shape[0] // 2
    Om = np.zeros((2 * k, 2 * k))
    Om[:k, k:] = np.eye(k)
    Om[k:, :k] = -np.eye(k)
    mags = np.sort(np.abs(np.linalg.eigvals(Om @ G).imag))
    nu = mags.reshape(k, 2).mean(axis=1)
    return np.clip(nu, 0.5, None)


def entropy_full(G):
    """Von Neumann entropy, in nats, of the Gaussian state with covariance G.

    S = sum_j [(nu_j + 1/2) ln(nu_j + 1/2) - (nu_j - 1/2) ln(nu_j - 1/2)]
    over the symplectic eigenvalues nu_j from symp_eigs. The second term is
    taken as 0 whenever nu_j - 1/2 <= 1e-12 (its limit at the vacuum).
    """
    nu = symp_eigs(G)
    plus = nu + 0.5
    minus = nu - 0.5
    gain = np.sum(plus * np.log(plus))
    loss_terms = np.where(minus > 1e-12,
                          minus * np.log(np.clip(minus, 1e-300, None)),
                          0)
    return float(gain - np.sum(loss_terms))


def pair_cov(G, i, j, n):
    """Return the 4x4 sub-covariance of modes i and j.

    G is (2n x 2n), ordered (q_0..q_{n-1}, p_0..p_{n-1}). The result is
    ordered (q_i, q_j, p_i, p_j), the layout logneg_pair and entropy_pair
    expect.
    """
    rows = np.array([i, j, n + i, n + j])
    return G[rows[:, None], rows]


def logneg_pair(s4):
    """E_N of a two-mode covariance in basis (q1,q2,p1,p2): flip p2.

    Partial transposition on mode 2 flips the sign of p2. The symplectic
    eigenvalues nu~ of the transposed covariance are the |Im| of the
    eigenvalues of Omega @ s4_T. These come in +- pairs, so consecutive pairs
    of the sorted magnitudes are averaged. Any nu~ < 1/2 contributes
    -ln(2 nu~) to E_N. For a physical two-mode state at most the smaller one
    can be below 1/2. Taking the upper half of the sorted list instead would
    return the larger nu~ twice, which gives E_N = 0 for every state.
    """
    D = np.diag([1.0, 1.0, 1.0, -1.0])
    st = D @ s4 @ D
    Om = np.array([[0, 0, 1, 0], [0, 0, 0, 1], [-1, 0, 0, 0], [0, -1, 0, 0]], float)
    mags = np.sort(np.abs(np.linalg.eigvals(Om @ st).imag))
    nu = mags.reshape(2, 2).mean(axis=1)
    nu = np.clip(nu, 1e-30, None)              # keep log finite
    return float(np.sum(np.where(nu < 0.5, -np.log(2 * nu), 0.0)))


def entropy_pair(s4):
    """Von Neumann entropy (nats) of a two-mode Gaussian state.

    s4 is the 4x4 covariance in basis (q1, q2, p1, p2). Its symplectic
    eigenvalues are the |Im| of the eigenvalues of Omega @ s4. These come in
    +- pairs, so consecutive pairs of the sorted magnitudes are averaged to
    get the two values. The values are then clipped at 0.5 so that round-off
    below the vacuum bound contributes nothing.
    S = sum_j [(nu_j+1/2) ln(nu_j+1/2) - (nu_j-1/2) ln(nu_j-1/2)], with the
    second term taken as 0 when nu_j - 1/2 <= 1e-12.
    """
    Om = np.array([[0, 0, 1, 0], [0, 0, 0, 1], [-1, 0, 0, 0], [0, -1, 0, 0]], float)
    mags = np.sort(np.abs(np.linalg.eigvals(Om @ s4).imag))
    nu = np.clip(mags.reshape(2, 2).mean(axis=1), 0.5, None)
    up, dn = nu + 0.5, nu - 0.5
    return float(np.sum(up * np.log(up)) -
                 np.sum(np.where(dn > 1e-12,
                                 dn * np.log(np.clip(dn, 1e-300, None)), 0)))


def run(d, verbose=False, *, peak_mode="max"):
    """Evolve chain + probes at separation d and locate the entanglement peak.

    The covariance is propagated exactly on the uniform grid
    ts = linspace(0, T_MAX, NT):
        Gamma(t + dt) = U Gamma(t) U^T,  U = expm(A dt),  A = [[0, 1], [-K, 0]].
    At every sample the probe pair's E_N (logneg_pair) and S(ab)
    (entropy_pair) are recorded.

    Parameters
    ----------
    d : probe separation passed to build().
    verbose : if True, print one line per time sample.
    peak_mode : "max" (default) takes the largest E_N in the window.
        "first" takes the first local maximum with E_N >= 1e-6, as described
        by the module protocol ("first entanglement peak"). On a coarse grid
        "first" can latch onto fast-oscillation crests.

    Returns
    -------
    dict(d=d, peak=None | dict(t, EN, Sab, F=Sab/EN), traj=[dict(t, EN, Sab)]).
    peak is None when E_N never reaches 1e-6 in the window.

    Raises ValueError for an unknown peak_mode.
    """
    if peak_mode not in ("max", "first"):
        raise ValueError(f"peak_mode must be 'max' or 'first', got {peak_mode!r}")
    n = N + 2
    K, ia, ib = build(d)
    A = np.zeros((2 * n, 2 * n))
    A[:n, n:] = np.eye(n)
    A[n:, :n] = -K
    G = initial_covariance(n)
    ts = np.linspace(0, T_MAX, NT)
    # One exact step propagator, reused for every (uniform) step.
    Ustep = expm(A * (ts[1] - ts[0]))
    traj = []
    for t in ts:
        s4 = pair_cov(G, ia, ib, n)
        row = dict(t=float(t), EN=logneg_pair(s4), Sab=entropy_pair(s4))
        traj.append(row)
        if verbose:
            print(f"  d={d:3d} t={row['t']:7.2f} "
                  f"E_N={row['EN']:.5f} S_ab={row['Sab']:.5f}")
        G = Ustep @ G @ Ustep.T
    EN = np.array([q['EN'] for q in traj])
    Sab = np.array([q['Sab'] for q in traj])
    if EN.max() < 1e-6:
        return dict(d=d, peak=None, traj=traj)
    if peak_mode == "max":
        k = int(np.argmax(EN))
    else:
        # Local maxima: not below the previous sample, strictly above the
        # next. The last occurrence of the global max always qualifies, so
        # the candidate list is non-empty here.
        rises = np.r_[True, EN[1:] >= EN[:-1]]
        falls = np.r_[EN[:-1] > EN[1:], True]
        k = int(np.flatnonzero(rises & falls & (EN >= 1e-6))[0])
    return dict(d=d, peak=dict(t=traj[k]['t'], EN=float(EN[k]), Sab=float(Sab[k]),
                               F=float(Sab[k] / EN[k])), traj=traj)


if __name__ == '__main__':
    import os
    from pathlib import Path

    print(f"N6 — the Landauer rate. chain N={N}, m={m}, probes W^2={W2}, eps={eps}")
    print(f"{'d':>4} {'t*':>7} {'E_N(t*)':>10} {'S_ab(t*)':>10} {'fee/nat F':>10} {'F*d':>8}")
    out = []
    for d in (4, 6, 8, 12, 16, 24):
        r = run(d)
        out.append(r)
        if r['peak']:
            p = r['peak']
            print(f"{d:4d} {p['t']:7.1f} {p['EN']:10.5f} {p['Sab']:10.5f} "
                  f"{p['F']:10.4f} {p['F'] * d:8.3f}")
        else:
            print(f"{d:4d}    — no entanglement transacted in window")
    # The power-law fit needs log(F); a zero fee has no logarithm, so only
    # strictly positive F values enter the fit.
    fit = [(r['d'], r['peak']['F']) for r in out
           if r['peak'] and r['peak']['F'] > 0]
    if len(fit) >= 3:
        ds, Fs = zip(*fit)
        pw = np.polyfit(np.log(ds), np.log(Fs), 1)[0]
        print(f"\nfee-per-nat law: F(d) ~ d^{pw:+.2f}   (diamond pricing predicts -1)")
        print("GATE: " + ("PASSED — bigger rooms charge less per nat." if pw < -0.5 else
                          "NOT PASSED in this arena — see notes."))
    # Results go next to this script unless N6_RESULTS names another file.
    # The previous hard-coded absolute path only existed on one machine.
    out_path = Path(os.environ.get('N6_RESULTS')
                    or Path(__file__).resolve().with_name('n6_results.json'))
    with open(out_path, 'w') as fh:
        json.dump([dict(d=r['d'], peak=r['peak']) for r in out], fh, indent=1)
    print(f"-> {out_path}")
