"""
n19 — The exact non-Gaussian test. Does the record-shadow attraction of n18
survive when the interacting chain is solved EXACTLY (full Lindblad, no
Hartree, no Gaussian approximation)?

n18 (Hartree) found: dephasing scatterers attract in a quartic medium, with
the force carried by a flux x medium-response interplay term. A skeptic's
objection stands: Hartree *inserts by hand* a Kerr-like medium (stiffness
responding to intensity). This script removes the approximation entirely:
a 6-site anharmonic chain, H = sum p^2/2 + (x_{i+1}-x_i)^2/2 + mu^2 x^2/2
+ (g/4) x^4, boundary sites damped (T=0 amplitude damping, rate kappa),
two "grains" = pure x^2 dephasing Lindblad noise (records, zero mean
absorption), steady state by exact integration of the full Liouvillian,
force from the ABBA-subtracted T^xx profile (including the quartic term
-(g/4)x^4 in the stress).

PRE-REGISTERED GATES
 V1  Code + truncation validation: with LINEAR jumps (implemented as the
     thermal pair sqrt(gamma (nbar+1)) a_j, sqrt(gamma nbar) a_j^dagger on
     the grain sites) the dynamics is exactly Gaussian; the exact solver
     must match a Lyapunov twin of the same 6-site system (covariances and
     ABBA force) to truncation error (<~1% at n_max=4).
 V2  Sign calibration: static stiffness bumps at the grain sites must
     attract (Casimir; Kenneth-Klich), fixing the sign convention.
 S   The verdict scan: x^2-noise grains, g in {0, 0.5, 1, 2, 3}.
     HARTREE VINDICATED if F_att(g) rises from the g=0 value and turns
     attractive at strong coupling; HARTREE ARTIFACT (and n18 demoted to
     a statement about Kerr-like media only) if F stays at or below its
     g=0 value across the scan.
 T   Truncation honesty: top-Fock occupation reported for every run;
     any verdict row with P(n_top) > 3% is flagged unusable; n_max=4
     spot-check of the verdict rows.

Geometry: sites 0..5; baths at 0 and 5; grains at 1 and 4; force read
from the T^xx jumps across each grain (bonds 0-1 vs 1-2, and 3-4 vs 4-5).
Runtime: n_max=3 scan ~ 1 h; n_max=4 spot-check ~ overnight (RK4 path).
"""

import sys
import time
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import expm_multiply
from scipy.linalg import solve_continuous_lyapunov

# Chain length: optional second command-line argument, default 6. It is read
# here at module level, so importing this file also consults sys.argv
# (argv[1], the Fock cutoff n_max, is read in the __main__ section).
L      = int(sys.argv[2]) if len(sys.argv) > 2 else 6
MU2    = 1.0       # MASSIVE medium, deliberately: with the near-massless
                   # mu^2 = 4e-4 of n13/n18, the chain GROUND STATE itself
                   # overflows a small local Fock ladder (IR modes at
                   # omega ~ 0.02 give n_bar ~ 0.7 per site; P_top hit 27%
                   # with zero heating — caught twice by the V1 null gate,
                   # chronicle item 7). The flip question is not an IR
                   # question; at mu^2 = 1, n_bar ~ 0.04 and the truncation
                   # converges. Range physics is out of scope at d = 3.
BATH   = "site"    # "site": a_j damping everywhere (+KAPPA at the ends) —
                   # converges fast; its truncation bias on the FORCE is
                   # measured directly by the V1 twin comparison (gate:
                   # bias < 25% of the predicted S-scan peak).
                   # "mode": vacuum-continuation bath — exact ground state
                   # at g=0 in principle, but truncated mode operators
                   # develop slow quasi-conserved dynamics (kept for
                   # reference; chronicle).
KAPPA  = 0.2       # extra end-site damping when BATH == "site"
ETA    = 0.1       # Base damping rate. With BATH == "site" it is the a_j
                   # damping rate applied to EVERY site (see twin_damping and
                   # Chain.base_jumps); with BATH == "mode" it is the
                   # NORMAL-MODE damping: each eigenmode of the bare chain
                   # is drained toward its own vacuum ("vacuum continuation",
                   # n13's bath philosophy ported to the exact code). The
                   # bisection record that forced this design: (run 1) end
                   # baths only -> interior heats, P_top 29%; (run 3) local
                   # site damping fights the hopping and the truncation
                   # rectifies the amputated squeezing into fake heat
                   # (L=2 study: n_bar 0.48 at n_max=3 vs 0.08 exact,
                   # converged only at n_max >= 8). Mode damping has zero
                   # conflict with H at g=0: the undriven NESS is the exact
                   # chain ground state, and small ladders suffice again.
                   # Chronicle items 6-8.
GLIN   = 0.05      # V1 thermal-bath rate on grain sites
NBAR   = 0.2       # V1 thermal-bath occupation
GSQ    = 0.01      # x^2 dephasing rate for the verdict scan
# W0 equals the square root of the interior diagonal of bare_K (2 + MU2); the
# same value defines a_j, x_j, p_j on every site, end sites included.
W0     = np.sqrt(2.0 + MU2)   # local frequency; sets the Fock ladder scale
BATHS  = (0, L - 1)   # end sites that receive the extra KAPPA damping
GRAINS = (1, L - 2)   # next-to-end sites that carry the grain noise
T0     = time.time()  # wall-clock origin for the timestamps printed by say()

def say(msg):
    """Print ``msg`` prefixed with the seconds elapsed since ``T0``.

    Output is flushed immediately so progress is visible on long runs.
    """
    stamped = f"[{time.time()-T0:7.1f}s] {msg}"
    print(stamped, flush=True)

def bare_K():
    """Return the (L, L) stiffness matrix of the bare harmonic chain.

    Diagonal: ``2 + MU2`` on interior sites and ``1 + MU2`` on the two end
    sites (an end site has a single bond). Nearest neighbours are coupled
    by ``-1``.
    """
    sites = np.arange(L)
    interior = (sites > 0) & (sites < L - 1)
    bond_weight = np.where(interior, 1.0, 0.5)

    K = np.zeros((L, L))
    K[sites, sites] = 2 * bond_weight + MU2
    left, right = sites[:-1], sites[1:]
    K[left, right] = -1.0
    K[right, left] = -1.0
    return K

def bare_modes():
    """Diagonalise the bare stiffness matrix.

    Returns the result of ``np.linalg.eigh(bare_K())``: the eigenvalues
    (squared mode frequencies, ascending) and the matrix whose columns are
    the corresponding eigenvectors.
    """
    stiffness = bare_K()
    return np.linalg.eigh(stiffness)

def gs_cov_bare():
    """Ground-state covariance of the bare chain in (x..., p...) ordering.

    Returns a (2L, 2L) block-diagonal matrix: the xx block is
    ``U diag(1 / (2 w)) U^T`` and the pp block is ``U diag(w / 2) U^T``,
    with ``w`` the mode frequencies and ``U`` the modes from ``bare_modes``.
    The xp blocks are zero.
    """
    freq_sq, modes = bare_modes()
    freq = np.sqrt(freq_sq)
    xx_block = modes @ np.diag(1 / (2 * freq)) @ modes.T
    pp_block = modes @ np.diag(freq / 2) @ modes.T

    cov = np.zeros((2 * L, 2 * L))
    cov[:L, :L] = xx_block
    cov[L:, L:] = pp_block
    return cov

def twin_damping(A, D):
    """Add the base bath to a Gaussian drift/diffusion pair, in place.

    This is the covariance-language counterpart of ``Chain.base_jumps``.

    A: (2L, 2L) drift matrix in (x..., p...) ordering; modified in place.
    D: (2L, 2L) diffusion matrix; modified in place.

    With ``BATH == "site"`` every site is damped at rate ETA (plus KAPPA on
    each entry of BATHS), with vacuum diffusion set by the ladder frequency
    W0. Otherwise all quadratures are damped at ETA / 2 and the diffusion is
    ETA times the bare ground-state covariance.

    Returns the same ``(A, D)`` objects.
    """
    if BATH != "site":
        A -= (ETA / 2) * np.eye(2 * L)
        D += ETA * gs_cov_bare()
        return A, D

    site_rate = [ETA] * L
    for end in BATHS:
        site_rate[end] += KAPPA

    for site, rate in enumerate(site_rate):
        mom = L + site
        A[site, site] -= rate / 2
        A[mom, mom] -= rate / 2
        D[site, site] += rate / (2 * W0)
        D[mom, mom] += rate * W0 / 2
    return A, D

# ---------------------------------------------------------------- operators

def site_matrices(nmax):
    """Single-site ladder operators truncated to ``nmax`` Fock levels.

    Returns ``(a, x, p)`` as dense arrays, with ``x = (a + a^T) / sqrt(2 W0)``
    and ``p = i (a^T - a) sqrt(W0 / 2)``; ``a`` is real, so ``a^T`` is its
    adjoint.
    """
    ladder = np.sqrt(np.arange(1, nmax))
    lower = np.diag(ladder, 1)
    lift = lower.T
    position = (lower + lift) / np.sqrt(2 * W0)
    momentum = 1j * (lift - lower) * np.sqrt(W0 / 2)
    return lower, position, momentum

def embed(op, j, nmax):
    """Lift a single-site operator to the full L-site space.

    Builds the Kronecker product over sites 0..L-1 with ``op`` at site ``j``
    and the ``nmax``-dimensional identity everywhere else; site 0 is the
    leftmost (slowest-varying) factor. Returns a complex CSR matrix.
    """
    full = sp.identity(1, format="csr", dtype=complex)
    for site in range(L):
        if site == j:
            factor = sp.csr_matrix(op)
        else:
            factor = sp.identity(nmax)
        full = sp.kron(full, factor, format="csr")
    return full

class Chain:
    """Truncated-Fock model of the L-site chain with an exact Lindblad solver.

    Every site keeps ``nmax`` ladder levels, so operators are sparse
    ``(nmax**L, nmax**L)`` matrices and density matrices are dense arrays of
    the same shape.

    Attributes (lists are indexed by site unless noted):
        nmax, g    -- ladder size per site and quartic coupling.
        X, P, A    -- embedded x_j, p_j and a_j.
        P2, X2, X4 -- embedded p_j^2, x_j^2, x_j^4, formed as products of the
                      truncated single-site matrices.
        DX2        -- bond operators (x_{j+1} - x_j)^2, indexed by bond j.
        H          -- Hamiltonian (CSR).
        PTOP       -- site-averaged projector onto the top ladder level.
    """

    def __init__(self, nmax, g):
        """Build the embedded operators and the Hamiltonian.

        nmax: number of Fock levels kept per site.
        g:    quartic coupling; H contains (g / 4) x_j^4 on every site.
        """
        self.nmax, self.g = nmax, g
        a, x, p = site_matrices(nmax)
        self.X  = [embed(x, j, nmax) for j in range(L)]
        # p @ p is real (p is i times a real matrix); drop the zero imaginary
        # part before embedding.
        self.P2 = [embed((p @ p).real, j, nmax) for j in range(L)]
        self.P  = [embed(p, j, nmax) for j in range(L)]
        self.X2 = [embed(x @ x, j, nmax) for j in range(L)]
        self.X4 = [embed(x @ x @ x @ x, j, nmax) for j in range(L)]
        self.A  = [embed(a, j, nmax) for j in range(L)]
        # bond operators (x_{j+1}-x_j)^2
        self.DX2 = [(self.X[j + 1] - self.X[j]) @ (self.X[j + 1] - self.X[j])
                    for j in range(L - 1)]
        # Hamiltonian: 1/2 p^2 + c_j x^2 + mu^2/2 x^2 + g/4 x^4 - sum x x'
        # (c_j = 1 in the interior, 1/2 at the ends: the expanded bond terms)
        H = sp.csr_matrix((nmax ** L, nmax ** L), dtype=complex)
        for j in range(L):
            cj = 1.0 if 0 < j < L - 1 else 0.5
            H = H + 0.5 * self.P2[j] + (cj + MU2 / 2) * self.X2[j] \
                + (g / 4) * self.X4[j]
        for j in range(L - 1):
            H = H - self.X[j] @ self.X[j + 1]
        self.H = H.tocsr()
        # top-Fock projector (per site, summed) for truncation honesty
        top = np.zeros(nmax)
        top[-1] = 1.0
        self.PTOP = sum(embed(np.diag(top), j, nmax) for j in range(L)) / L

    def liouvillian(self, jumps):
        """Explicit sparse Lv, column-stacking convention.

        jumps: iterable of sparse jump operators c (rates already absorbed
               as sqrt(rate) prefactors).

        With vec(A X B) = (B^T kron A) vec(X), the generator is
        -i (I kron H - H^T kron I) plus, per jump,
        conj(c) kron c - (I kron c^+c + (c^+c)^T kron I) / 2.
        Returns a CSR matrix of shape (d*d, d*d), d = H.shape[0].
        """
        d = self.H.shape[0]
        I = sp.identity(d, format="csr", dtype=complex)
        Lv = -1j * (sp.kron(I, self.H) - sp.kron(self.H.T, I))
        for c in jumps:
            cd = c.conj().T.tocsr()
            cdc = (cd @ c).tocsr()
            Lv = Lv + sp.kron(c.conj(), c) \
                - 0.5 * sp.kron(I, cdc) - 0.5 * sp.kron(cdc.T, I)
        return Lv.tocsr()

    def ness(self, jumps, label, tol=1e-8, tchunk=3.0, maxchunk=60):
        """Relax the product vacuum to the steady state of the Liouvillian.

        The state is propagated in chunks of duration ``tchunk`` with
        ``expm_multiply`` and renormalised to unit trace after each chunk.
        Iteration stops once ``||Lv vec(rho)|| < tol`` or after ``maxchunk``
        chunks; in the latter case a warning is printed and the last iterate
        is returned anyway.

        jumps: jump operators passed to ``liouvillian``.
        label: tag used in the progress messages.
        Returns the dense (d, d) density matrix.
        """
        d = self.H.shape[0]
        Lv = self.liouvillian(jumps)
        # The scaled generator is the same for every chunk: build it once
        # instead of re-allocating a d^2 x d^2 sparse matrix per iteration.
        step = Lv * tchunk
        rho = np.zeros((d, d), dtype=complex)
        rho[0, 0] = 1.0                      # product vacuum
        v = rho.flatten(order="F")
        res = np.inf                         # defined even if maxchunk == 0
        for it in range(maxchunk):
            v = expm_multiply(step, v)
            tr = v.reshape((d, d), order="F").trace().real
            v = v / tr
            res = np.linalg.norm(Lv @ v)
            if it % 5 == 0 or res < tol:
                say(f"      [{label}] chunk {it:2d}  ||Lv rho|| = {res:.2e}")
            if res < tol:
                break
        else:
            say(f"      [{label}] WARN not converged (res {res:.2e})")
        return v.reshape((d, d), order="F")

    def tr(self, op, rho):
        """Return Re Tr(op rho) for a sparse ``op`` and a dense ``rho``.

        Uses Tr(op rho) = sum_ij op_ij rho_ji, so only the stored entries of
        ``op`` are touched.
        """
        return (op.multiply(rho.T)).sum().real

    def stress(self, rho):
        """T^xx per bond, incl. the quartic term.

        For bond j the on-site expectations <p^2>, <x^2>, <x^4> are averaged
        over the two sites j and j+1. Returns an array of length L - 1.
        """
        T = np.zeros(L - 1)
        for j in range(L - 1):
            pp = 0.5 * (self.tr(self.P2[j], rho) + self.tr(self.P2[j + 1], rho))
            dxx = self.tr(self.DX2[j], rho)
            xx = 0.5 * (self.tr(self.X2[j], rho) + self.tr(self.X2[j + 1], rho))
            x4 = 0.5 * (self.tr(self.X4[j], rho) + self.tr(self.X4[j + 1], rho))
            T[j] = 0.5 * pp + 0.5 * dxx - 0.5 * MU2 * xx - (self.g / 4) * x4
        return T

    def base_jumps(self):
        """Return the jump operators of the base bath (no grains).

        BATH == "site": sqrt(KAPPA) a_b on each entry of BATHS plus
        sqrt(ETA) a_j on every site. Otherwise: sqrt(ETA) times the
        annihilation operator of each normal mode of the BARE chain.
        """
        if BATH == "site":
            return [np.sqrt(KAPPA) * self.A[b] for b in BATHS] \
                 + [np.sqrt(ETA) * self.A[j] for j in range(L)]
        # vacuum-continuation bath: damp the BARE chain's normal modes
        w2, U = bare_modes()
        ops = []
        for k in range(L):
            wk = np.sqrt(w2[k])
            Bk = None
            for j in range(L):
                if abs(U[j, k]) < 1e-14:     # skip sites with no weight
                    continue
                t = U[j, k] * (np.sqrt(wk / 2) * self.X[j]
                               + (1j / np.sqrt(2 * wk)) * self.P[j])
                Bk = t if Bk is None else Bk + t
            ops.append((np.sqrt(ETA) * Bk).tocsr())
        return ops

    def force_abba(self, noise_kind, gam, label):
        """ABBA force between the two grains; returns (F, FA, FB, ptop).

        Four steady states are solved (no grain, grain A only, grain B only,
        both) and the stress profiles combined as AB - A0 - 0B + 00, which
        removes everything that is not a joint effect of the two grains.

        noise_kind: "lin" selects the thermal pair
                    {sqrt(gam (NBAR+1)) a, sqrt(gam NBAR) a^+} on each grain;
                    any other value selects the dephasing jump sqrt(gam) x^2.
        gam:        grain noise rate.
        label:      prefix for the progress messages.

        FA and FB are the stress jumps across grains A and B (signs chosen
        so both count toward the other grain); F is their mean. ``ptop`` is
        the top-level occupation of the AB configuration only.
        """
        base = self.base_jumps()
        def grain_ops(sites):
            if noise_kind == "lin":   # thermal pair: bounded, exactly Gaussian
                ops = []
                for s in sites:
                    ops.append(np.sqrt(gam * (NBAR + 1)) * self.A[s])
                    ops.append(np.sqrt(gam * NBAR) * self.A[s].conj().T.tocsr())
                return ops
            return [np.sqrt(gam) * self.X2[s] for s in sites]
        T_by, ptop = {}, 0.0
        for tag, sites in (("00", ()), ("A0", (GRAINS[0],)),
                           ("0B", (GRAINS[1],)), ("AB", GRAINS)):
            rho = self.ness(base + grain_ops(sites), f"{label}:{tag}")
            T_by[tag] = self.stress(rho)
            if tag == "AB":
                ptop = self.tr(self.PTOP, rho)
        dT = T_by["AB"] - T_by["A0"] - T_by["0B"] + T_by["00"]
        a, b = GRAINS
        FA = dT[a - 1] - dT[a]
        FB = dT[b] - dT[b - 1]
        return 0.5 * (FA + FB), FA, FB, ptop

# ------------------------------------------------- Gaussian twin (V1 gate)

def lyapunov_twin(gam_lin):
    """Same system, linear jumps: exact Gaussian NESS.

    gam_lin: rate of the thermal pair {a, a+} (occupation NBAR) placed on
             the grain sites.

    Solves the Lyapunov equation for the four ABBA configurations and
    returns ``(F, sigma_AB)``: the ABBA force (mean of the stress jumps
    across the two grains) and the (2L, 2L) covariance with both grains on.
    """
    stiffness = bare_K()
    heat = gam_lin * (2 * NBAR + 1)

    def steady_cov(sites):
        drift = np.zeros((2 * L, 2 * L))
        drift[:L, L:] = np.eye(L)
        drift[L:, :L] = -stiffness
        diffusion = np.zeros((2 * L, 2 * L))
        drift, diffusion = twin_damping(drift, diffusion)
        # thermal pair {a, a+} at rate gam_lin, occupation NBAR
        for s in sites:
            drift[s, s] -= gam_lin / 2
            drift[L + s, L + s] -= gam_lin / 2
            diffusion[s, s] += heat / (2 * W0)
            diffusion[L + s, L + s] += heat * W0 / 2
        return solve_continuous_lyapunov(drift, -diffusion)

    def bond_stress(sig):
        variances = np.diag(sig)
        xx, pp = variances[:L], variances[L:]
        cross = np.diag(sig, k=1)[:L - 1]          # sig[j, j+1] per bond
        pp_bond = 0.5 * (pp[:-1] + pp[1:])
        dxx = xx[:-1] + xx[1:] - 2 * cross
        xx_bond = 0.5 * (xx[:-1] + xx[1:])
        return 0.5 * pp_bond + 0.5 * dxx - 0.5 * MU2 * xx_bond

    first, second = GRAINS
    layouts = {"00": (), "A0": (first,), "0B": (second,), "AB": GRAINS}
    cov_by = {tag: steady_cov(sites) for tag, sites in layouts.items()}
    T_by = {tag: bond_stress(sig) for tag, sig in cov_by.items()}

    dT = T_by["AB"] - T_by["A0"] - T_by["0B"] + T_by["00"]
    force_a = dT[first - 1] - dT[first]
    force_b = dT[second] - dT[second - 1]
    return 0.5 * (force_a + force_b), cov_by["AB"]

# ---------------------------------------- Hartree twin on the same arena --

def hartree_twin_force(g, gam):
    """n13/n18-style self-consistent Gaussian + Hartree treatment of the
    x^2-record noise, on the exact test's own geometry. Defines what
    'Hartree predicts' for this arena. Mapping: D[sqrt(gam) x^2] drives
    d<p^2>/dt = 4 gam <x^2>, i.e. the n13 scheme with gam_eff = 4 gam.

    g:   quartic coupling.
    gam: x^2 dephasing rate on the grain sites.

    Returns the ABBA force (mean of the stress jumps across the two grains).
    A warning is printed through ``say`` if the Hartree fixed point has not
    converged after 300 iterations; the last iterate is still used.
    """
    def build(mu2v, noisy):
        """Return (sig, sig0): covariance with / without the grain noise,
        both for the on-site stiffness profile ``mu2v``."""
        K = bare_K()
        for j in range(L):
            K[j, j] += mu2v[j] - MU2      # Hartree dressing on top of bare
        A = np.zeros((2 * L, 2 * L))
        A[:L, L:] = np.eye(L)
        A[L:, :L] = -K
        D = np.zeros((2 * L, 2 * L))
        A, D = twin_damping(A, D)
        sig0 = solve_continuous_lyapunov(A, -D)
        if not noisy:
            return sig0, sig0
        # Response to a unit momentum-diffusion source on each noisy site.
        parts, b_v = [], []
        for s_ in noisy:
            src = np.zeros((2 * L, 2 * L))
            src[L + s_, L + s_] = 1.0
            parts.append(solve_continuous_lyapunov(A, -src))
            b_v.append(sig0[s_, s_])
        ge = 4.0 * gam
        n = len(noisy)
        # The source strength on site i is ge <x_i^2>, and <x_i^2> itself
        # contains the sourced parts: solve the linear self-consistency.
        M = np.array([[ge * parts[jj][noisy[ii], noisy[ii]] for jj in range(n)]
                      for ii in range(n)])
        s_v = np.linalg.solve(np.eye(n) - M, np.array(b_v))
        sig = sig0 + sum(ge * s_v[j] * parts[j] for j in range(n))
        return sig, sig0

    def converged(noisy):
        """Damped fixed-point iteration for the dressed stiffness profile."""
        mu2v = np.full(L, MU2)
        for _ in range(300):
            sig, sig0 = build(mu2v, noisy)
            dx2 = np.diag(sig)[:L] - np.diag(sig0)[:L]
            new = 0.4 * mu2v + 0.6 * (MU2 + 3 * g * dx2)
            change = np.max(np.abs(new - mu2v))
            mu2v = new
            if change < 1e-12:
                break
        else:
            # Previously silent: an unconverged dressing would feed straight
            # into the reported Hartree prediction.
            say(f"      [hartree g={g}] WARN fixed point not converged "
                f"(last step {change:.2e})")
        sig, _ = build(mu2v, noisy)
        return sig, mu2v

    def stress(sig, mu2v):
        """T^xx per bond with the dressed mass term and Wick-factorised x^4."""
        T = np.zeros(L - 1)
        for j in range(L - 1):
            pp = 0.5 * (sig[L + j, L + j] + sig[L + j + 1, L + j + 1])
            dxx = sig[j, j] + sig[j + 1, j + 1] - 2 * sig[j, j + 1]
            xx_j, xx_j1 = sig[j, j], sig[j + 1, j + 1]
            mxx = 0.5 * (mu2v[j] * xx_j + mu2v[j + 1] * xx_j1)
            x4 = 0.5 * (3 * xx_j ** 2 + 3 * xx_j1 ** 2)   # Wick
            T[j] = 0.5 * pp + 0.5 * dxx - 0.5 * mxx - (g / 4) * x4
        return T

    T_by = {}
    for tag, sites in (("00", ()), ("A0", (GRAINS[0],)), ("0B", (GRAINS[1],)),
                       ("AB", GRAINS)):
        sig, mu2v = converged(list(sites))
        T_by[tag] = stress(sig, mu2v)
    dT = T_by["AB"] - T_by["A0"] - T_by["0B"] + T_by["00"]
    a, b = GRAINS
    return 0.5 * ((dT[a - 1] - dT[a]) + (dT[b] - dT[b - 1]))

# ------------------------------------------------------------------- main

if __name__ == "__main__":
    # Usage: <script> [n_max [L]]. L (argv[2]) is consumed at module level.
    nmax = int(sys.argv[1]) if len(sys.argv) > 1 else 3
    say(f"n19 — exact non-Gaussian test, L={L}, n_max={nmax}")

    # -- V1: linear jumps, g=0: exact vs Gaussian twin ---------------------
    say("V1: linear-jump validation (g=0) against the Lyapunov twin")
    ch = Chain(nmax, 0.0)
    F_ex, FA, FB, ptop = ch.force_abba("lin", GLIN, "V1")
    F_tw, sig_tw = lyapunov_twin(GLIN)
    # thermal jumps modify the drift locally, so the twin ABBA force is small
    # but nonzero: V1 is a quantitative match of exact vs twin
    say(f"V1: exact F = {F_ex:+.6e}   twin F = {F_tw:+.6e}   "
        f"dev = {abs(F_ex - F_tw):.2e}   P(top) = {ptop:.2%}")

    # covariance spot check on the AB config. The exact state must be built
    # with the SAME grain noise the twin covariance sig_tw describes: the
    # thermal pair {a, a+} at rate GLIN and occupation NBAR (as in
    # force_abba("lin", ...)). Hermitian sqrt(GLIN) x jumps are a different
    # model (pure momentum diffusion, no damping) and would not match.
    thermal_pair = []
    for s in GRAINS:
        thermal_pair.append(np.sqrt(GLIN * (NBAR + 1)) * ch.A[s])
        thermal_pair.append(np.sqrt(GLIN * NBAR) * ch.A[s].conj().T.tocsr())
    rho = ch.ness(ch.base_jumps() + thermal_pair, "V1cov")
    dev = 0.0
    for j in range(L):
        xx_ex = ch.tr(ch.X2[j], rho)
        dev = max(dev, abs(xx_ex - sig_tw[j, j]))
    say(f"V1: max |<x^2>_exact - <x^2>_twin| = {dev:.2e}")

    # -- V2: sign calibration ----------------------------------------------
    say("V2: static-bump sign calibration (g=0, no noise)")
    T_by = {}
    for tag, sites in (("00", ()), ("A0", (GRAINS[0],)), ("0B", (GRAINS[1],)),
                       ("AB", GRAINS)):
        chs = Chain(nmax, 0.0)
        for s in sites:
            # static stiffness bump on the grain site, no extra noise
            chs.H = (chs.H + 0.15 * chs.X2[s]).tocsr()
        rho = chs.ness(chs.base_jumps(), f"V2:{tag}")
        T_by[tag] = chs.stress(rho)
    dT = T_by["AB"] - T_by["A0"] - T_by["0B"] + T_by["00"]
    a, b = GRAINS
    Fc = 0.5 * ((dT[a - 1] - dT[a]) + (dT[b] - dT[b - 1]))
    say(f"V2: static bumps F = {Fc:+.6e}   (must be the ATTRACTION sign)")

    # -- H: what Hartree predicts on THIS arena (fast, printed first) -------
    say("H: Hartree-twin prediction on the same 6-site geometry")
    for g in (0.0, 0.25, 0.5, 1.0, 2.0, 3.0, 5.0):
        # Deliberately best-effort: the Hartree column is a diagnostic, and
        # one failed coupling must not abort the expensive verdict scan.
        try:
            Fh = hartree_twin_force(g, GSQ)
            say(f"H: g = {g:<5} F_hartree = {Fh:+.6e}")
        except Exception as e:
            say(f"H: g = {g} failed: {e}")

    # -- S: the verdict scan -------------------------------------------------
    say(f"S: x^2-record noise (gamma={GSQ}), scan over interaction g")
    for g in (0.0, 0.5, 1.0, 2.0, 3.0):
        ch = Chain(nmax, g)
        F, FA, FB, ptop = ch.force_abba("sq", GSQ, f"g={g}")
        # gate T: rows with more than 3% top-level occupation are unusable
        flag = "" if ptop < 0.03 else "  [TRUNCATION FLAG]"
        say(f"S: g = {g:<4} F_att = {F:+.6e}  (A {FA:+.2e} / B {FB:+.2e})  "
            f"P(top) = {ptop:.2%}{flag}")
    say("done.")
