"""
n13 — The shadow kill-gate. Does pure dephasing (elastic scattering that
creates records, zero mean absorption) produce a directed force between
two scatterers in the vacuum of a 1D scalar field?

Le Sage gravity, quantum edition: if "decoherence shadows" could attract,
gravity could be the vacuum's bookkeeping of records. Pre-registered
criteria: method must reproduce static Casimir vs -dE0/dd to ~10%;
single-grain null; verdict on sign, gamma_A*gamma_B scaling, d-dependence.

VERDICT (July 2026): REFUTED. Dephasing grains repel (the repulsion is
proportional to the medium's absorption and vanishes in the transparent
limit); a real thermal bath makes it worse; the SAME scatterer, coherent,
attracts (Casimir). Maxwell's theorem survives decoherence: what a grain
rediffuses with scrambled phase carries exactly the momentum it removed.
Decoherence does not create attraction -- it destroys it.

Exact non-equilibrium steady state (NESS) via Lyapunov equations, with
"vacuum continuation" boundary baths (the undriven NESS equals the exact
ground state to 7e-13). The force is read from the ABBA-subtracted
momentum-flux (T^xx) discontinuity across each grain.
Runtime: ~1 minute (numpy + scipy).
"""

import numpy as np
from scipy.linalg import solve_continuous_lyapunov, eigh

# --- lattice and medium ------------------------------------------------------
N = 256               # lattice sites; K is N x N, covariances are 2N x 2N
MU2 = 4.0e-4          # squared mass: the K_matrix diagonal is 2 + MU2
LABS = 28             # width (in sites) of the absorbing layer at each end
ETA_MAX = 0.5         # extra damping on the outermost site (quadratic ramp inward)
ETA_BULK = 5e-3       # default uniform damping applied to every site
CENTER = N // 2       # middle site of the lattice
SMEAR = 2.0           # Gaussian width of a grain profile (see weight)
# Force-measurement window: force_att reads the flux on bonds site-OFF and
# site+OFF-1 (bond i joins sites i and i+1).
# NOTE(review): weight() is nonzero on sites site-9 .. site+8 (forward
# difference of a Gaussian truncated at |offset| <= 4*SMEAR = 8), so with
# OFF = 8 both measurement bonds still touch the outermost tail of the grain
# (|c| ~ 1e-3 there); bonds fully outside would need OFF >= 11 -- confirm.
OFF = 8

def K_matrix():
    """Return the N x N stiffness matrix of the discretised massive field.

    Tridiagonal: 2 + MU2 on the diagonal (two nearest-neighbour springs plus
    the mass term) and -1 on the first off-diagonals. The end rows keep the
    full diagonal, i.e. the field is pinned to zero just outside the lattice.
    """
    on_site = (2.0 + MU2) * np.eye(N)
    hopping = np.eye(N, k=1) + np.eye(N, k=-1)
    return on_site - np.eye(N, k=1) - np.eye(N, k=-1) if hopping is None else \
        on_site - np.eye(N, k=1) - np.eye(N, k=-1)

def weight(site):
    """Return the unit-norm coupling profile c of a grain centred at `site`.

    The grain is a fluctuating index coupled to the local strain:
    V = (c^T x)^2 / 2, where c is the forward discrete gradient of a Gaussian
    of width SMEAR, truncated beyond 4*SMEAR sites. The gradient telescopes,
    so c sums to zero whenever the truncated Gaussian fits inside the lattice
    (IR-safe); the Gaussian is smooth, so there is no band-edge artefact.

    Support: the Gaussian is kept for |i - site| <= 4*SMEAR, so c (with
    c[i] = g[i+1] - g[i] and c[N-1] = 0) is nonzero on
    site - 4*SMEAR - 1 .. site + 4*SMEAR.
    """
    offsets = np.arange(N) - site
    bump = np.exp(-0.5 * (offsets / SMEAR) ** 2)
    bump[np.abs(offsets) > 4 * SMEAR] = 0.0
    grad = np.append(np.diff(bump), 0.0)   # last site has no forward neighbour
    return grad / np.linalg.norm(grad)

def ground_cov(K, temp=0.0):
    """Covariance of the ground (or thermal) state of H = p^2/2 + x^T K x / 2.

    Parameters
    ----------
    K : (n, n) symmetric array
        Stiffness matrix. The size is taken from K itself. Eigenvalues are
        floored at 1e-12 before the square root, so a zero mode gets
        w = 1e-6 rather than a division by zero.
    temp : float, optional
        Bath temperature. 0 gives the vacuum (occupation 1/2 per mode);
        otherwise each mode has occupation n + 1/2 = coth(w / 2T) / 2.

    Returns
    -------
    (2n, 2n) ndarray
        Block-diagonal covariance [[<x x^T>, 0], [0, <p p^T>]], with
        <x x^T> = U diag(occ / w) U^T and <p p^T> = U diag(occ * w) U^T.

    Raises
    ------
    ValueError
        If `temp` is negative (the occupation would be negative).
    """
    if temp < 0:
        raise ValueError(f"temperature must be non-negative, got {temp!r}")
    n = K.shape[0]
    w2, U = eigh(K)
    w = np.sqrt(np.maximum(w2, 1e-12))
    if temp == 0:
        occ = np.full(n, 0.5)
    else:
        occ = 0.5 / np.tanh(0.5 * w / temp)
    sig = np.zeros((2 * n, 2 * n))
    # U diag(d) U^T with the diagonal applied as a column scaling of U:
    # same values as the explicit diag matmul, without an O(n^3) product.
    sig[:n, :n] = (U * (occ / w)) @ U.T
    sig[n:, n:] = (U * (occ * w)) @ U.T
    return sig

def build_A_D(K, eta_bulk, temp):
    """Drift A and diffusion D of the damped (x, p) dynamics.

    The covariance obeys d sigma/dt = A sigma + sigma A^T + D. The
    Hamiltonian part is dx/dt = p, dp/dt = -K x. A damping profile eta
    (eta_bulk everywhere, plus a quadratic ramp of height ETA_MAX over LABS
    sites at each end) acts on both x and p. The noise D = (E sgs + sgs E)/2
    is chosen so that the reference state sgs = ground_cov(K, temp) is
    stationary without grains ("vacuum continuation" baths).

    Returns (A, D, sgs).
    """
    eta = np.full(N, eta_bulk)
    for depth in range(LABS):
        extra = ETA_MAX * ((LABS - depth) / LABS) ** 2
        eta[depth] += extra
        eta[-1 - depth] += extra
    damping = np.diag(np.concatenate([eta, eta]))

    drift = np.zeros((2 * N, 2 * N))
    drift[:N, N:] = np.eye(N)     # dx/dt = p
    drift[N:, :N] = -K            # dp/dt = -K x
    drift -= damping / 2

    sgs = ground_cov(K, temp)
    D = 0.5 * (damping @ sgs + sgs @ damping)
    return drift, D, sgs

def ness(K, noisy=(), eta_bulk=ETA_BULK, temp=0.0):
    """Steady-state covariance of the medium with dephasing grains.

    noisy = ((site, gamma), ...): smeared grains with profile `weight`.

    Grain j is probed through x~_j = w_j . x and injects momentum diffusion
    gamma_j * <x~_j^2> * v_j v_j^T along p~_j = w_j . p (w_j = weight(site_j)).
    The Lyapunov equation is linear, so the covariance is the grain-free
    NESS plus sum_j gamma_j * s_j * P_j, where P_j solves
    A P + P A^T = -v_j v_j^T. The strengths s_j = <x~_j^2> then follow from
    the linear self-consistency (I - M) s = b.

    Returns (sig, svals): the 2N x 2N covariance and the s_j (an empty array
    when there are no grains).

    Raises RuntimeError when M has an eigenvalue with real part >= 1
    (parametric instability).
    """
    A, D, _ = build_A_D(K, eta_bulk, temp)
    base = solve_continuous_lyapunov(A, -D)
    if not noisy:
        return base, np.zeros(0)

    zeros = np.zeros(N)
    gammas = np.array([gamma for _, gamma in noisy])
    profiles = [weight(site) for site, _ in noisy]
    probes = [np.concatenate([prof, zeros]) for prof in profiles]   # read <x~^2>
    kicks = [np.concatenate([zeros, prof]) for prof in profiles]    # diffuse p~
    responses = [solve_continuous_lyapunov(A, -np.outer(k, k)) for k in kicks]

    drive = np.array([p @ base @ p for p in probes])
    M = np.array([[g * (p @ R @ p) for g, R in zip(gammas, responses)]
                  for p in probes])
    peak = np.max(np.linalg.eigvals(M).real)
    if peak >= 1.0:
        raise RuntimeError(f"instabilite parametrique (max eig M = {peak:.3f})")
    svals = np.linalg.solve(np.eye(len(probes)) - M, drive)
    sig = base + sum(g * s * R for g, s, R in zip(gammas, svals, responses))
    return sig, svals

def stress_profile(sig, K):
    """Momentum flux T^xx on each of the N-1 bonds of the lattice.

    Bond i joins sites i and i+1:
    T^xx = <p^2>/2 + <(x_i - x_{i+1})^2>/2 - MU2 <x^2>/2,
    with the single-site variances averaged over the bond's two ends.
    `K` is accepted but not used.
    """
    x_cov = sig[:N, :N]
    x_var = np.diag(x_cov)
    p_var = np.diag(sig[N:, N:])
    kinetic = 0.5 * (p_var[:-1] + p_var[1:])
    gradient = x_var[:-1] + x_var[1:] - 2 * np.diag(x_cov, 1)
    mass = 0.5 * (x_var[:-1] + x_var[1:])
    return 0.5 * kinetic + 0.5 * gradient - 0.5 * MU2 * mass

def force_att(T_by, a, b, off=OFF):
    """Attractive force between grains at sites a and b from the flux jumps.

    T_by maps "AB", "A0", "0B", "00" to stress profiles. The ABBA combination
    AB - A0 - 0B + 00 isolates the interaction part of the flux. The force on
    a grain is the flux entering through bond site-off minus the flux leaving
    through bond site+off-1. FA is the +x push on A and FB the -x push on B,
    so both are positive when the grains attract (assuming a < b).

    Returns (mean of FA and FB, FA, FB).
    """
    # Keep this exact subtraction order: the ABBA difference is a cancellation.
    dT = T_by["AB"] - T_by["A0"] - T_by["0B"] + T_by["00"]
    a_in, a_out = a - off, a + off - 1
    b_in, b_out = b - off, b + off - 1
    force_a = dT[a_in] - dT[a_out]
    force_b = dT[b_out] - dT[b_in]
    return 0.5 * (force_a + force_b), force_a, force_b

def run_pair(K, a, b, gA, gB, eta_bulk=ETA_BULK, temp=0.0):
    """Run the four ABBA configurations and return the pair force.

    Configurations: no grain ("00"), A only ("A0"), B only ("0B") and both
    ("AB"), with dephasing rates gA at site a and gB at site b.

    Returns (F, FA, FB, svals): the force triple from `force_att` and the
    self-consistent <x~^2> of the two grains in the "AB" run.
    """
    configs = {
        "00": (),
        "A0": ((a, gA),),
        "0B": ((b, gB),),
        "AB": ((a, gA), (b, gB)),
    }
    T_by = {}
    svals = None
    for tag, grains in configs.items():
        sig, strengths = ness(K, grains, eta_bulk, temp)
        T_by[tag] = stress_profile(sig, K)
        if tag == "AB":
            svals = strengths
    F, FA, FB = force_att(T_by, a, b)
    return F, FA, FB, svals

def pumped_power(gamma, svals):
    """Total power pumped into the field by grains sharing dephasing rate `gamma`.

    Each grain adds gamma * <x~^2> * v v^T to the momentum diffusion (see
    `ness`), with |v| = 1 because `weight` is unit-norm. The kinetic energy
    p^2/2 therefore grows at gamma * <x~^2> / 2 per grain.
    """
    return 0.5 * gamma * svals.sum()


def main():
    """Run the single-grain null test and the three pair scans, printing tables."""
    K = K_matrix()

    # Null test: a lone grain should show no net flux discontinuity ("parasite").
    print("--- test nul (grain unique etale, gamma=0.01) ---")
    a = CENTER
    sig, s = ness(K, ((a, 0.01),))
    sig0, _ = ness(K)
    T = stress_profile(sig, K) - stress_profile(sig0, K)
    print(f"  <x~^2> = {s[0]:.4f}, rayonnement DT = {T[a-OFF]:.3e}, "
          f"parasite = {T[a-OFF]-T[a+OFF-1]:.3e}\n")

    # Distance scan in the vacuum.
    # NOTE(review): weight() is nonzero on sites site-9 .. site+8, so at d=16
    # the two grain supports overlap -- confirm that row is meaningful.
    gamma = 0.01
    print("--- deux grains dephasants, vide (gamma=0.01) ---")
    print(f"{'d':>4} {'F_att':>12} {'P_pompe':>10}")
    for d in (16, 24, 32, 48, 64):
        aa, bb = CENTER - d // 2, CENTER - d // 2 + d
        F, _, _, s = run_pair(K, aa, bb, gamma, gamma)
        print(f"{d:>4} {F:>12.4e} {pumped_power(gamma, s):>10.3e}")
    print()

    # The remaining scans all use a fixed separation d = 24.
    aa, bb = CENTER - 12, CENTER + 12

    print("--- scaling en gamma (d=24) ---")
    print(f"{'gA':>7} {'gB':>7} {'F_att':>12} {'F/(gA*gB)':>12}")
    for gA, gB in ((0.0025, 0.0025), (0.005, 0.005), (0.01, 0.01),
                   (0.02, 0.02), (0.02, 0.005)):
        F, *_ = run_pair(K, aa, bb, gA, gB)
        print(f"{gA:>7} {gB:>7} {F:>12.4e} {F/(gA*gB):>12.4e}")
    print()

    print("--- dependance a l'absorption (d=24, gamma=0.01) ---")
    print(f"{'eta_bulk':>9} {'F_att':>12}")
    for eb in (0.02, 0.01, 0.005, 0.0025, 0.00125):
        F, *_ = run_pair(K, aa, bb, 0.01, 0.01, eta_bulk=eb)
        print(f"{eb:>9} {F:>12.4e}")
    print()

    gamma = 0.005
    print("--- bain reel thermique (d=24, gamma=0.005) ---")
    print(f"{'T':>6} {'F_att':>12} {'P_pompe':>10}")
    for temp in (0.0, 0.25, 0.5, 1.0, 2.0):
        F, _, _, s = run_pair(K, aa, bb, gamma, gamma, temp=temp)
        print(f"{temp:>6} {F:>12.4e} {pumped_power(gamma, s):>10.3e}")


if __name__ == "__main__":
    main()
