#!/usr/bin/env python3
"""
Verify the Fractran machine from Size21/Sz21_140_unofficial_1.lean.

Machine rules (tried in priority order R1..R5; the first rule whose left-hand
pattern matches fires, and the machine halts when none matches):
  R1: (a, b+1, c+1, d, e) -> (a, b, c, d, e)
  R2: (a, b, c, d+1, e+1) -> (a, b+2, c, d, e)
  R3: (a, b+1, c, d, e)   -> (a+1, b, c, d+2, e)
  R4: (a, b, c, d+2, e)   -> (a, b, c+1, d, e)
  R5: (a+1, b, c, d, e)   -> (a, b+1, c, d, e+1)

Starting state: (1, 0, 0, 0, 0)
"""


def step(state):
    """Advance the machine by exactly one transition.

    Rules are tried in priority order R1..R5 and the first matching one fires.
    Returns the successor 5-tuple, or None when no rule applies (the machine
    has halted).
    """
    a, b, c, d, e = state

    if b >= 1 and c >= 1:
        # R1: (a, b+1, c+1, d, e) -> (a, b, c, d, e)
        successor = (a, b - 1, c - 1, d, e)
    elif d >= 1 and e >= 1:
        # R2: (a, b, c, d+1, e+1) -> (a, b+2, c, d, e)
        successor = (a, b + 2, c, d - 1, e - 1)
    elif b >= 1:
        # R3: (a, b+1, c, d, e) -> (a+1, b, c, d+2, e)
        successor = (a + 1, b - 1, c, d + 2, e)
    elif d >= 2:
        # R4: (a, b, c, d+2, e) -> (a, b, c+1, d, e); needs two units of d
        successor = (a, b, c + 1, d - 2, e)
    elif a >= 1:
        # R5: (a+1, b, c, d, e) -> (a, b+1, c, d, e+1)
        successor = (a - 1, b + 1, c, d, e + 1)
    else:
        # No rule matches: halted.
        successor = None

    return successor


def run_until(state, predicate, max_steps=10_000_000):
    """Step the machine until ``predicate`` accepts a state.

    The predicate is checked only after each step, never on ``state`` itself,
    so at least one step is always taken. Returns ``(state, steps)``, where
    ``steps`` is the number of transitions performed.

    Raises RuntimeError if the machine halts first, or if ``max_steps``
    transitions pass without the predicate becoming true.
    """
    current = state
    taken = 0
    while taken < max_steps:
        successor = step(current)
        if successor is None:
            raise RuntimeError(
                f"Machine halted at {current} after {taken} steps "
                f"before predicate was satisfied"
            )
        current = successor
        taken += 1
        if predicate(current):
            return current, taken
    raise RuntimeError(f"Exceeded {max_steps} steps without satisfying predicate")


def is_canonical(s):
    """Return True when registers b, c and e are all zero, i.e. s = (a, 0, 0, d, 0)."""
    # Indices checked left to right; all() stops at the first non-zero register.
    return all(s[idx] == 0 for idx in (1, 2, 4))


def run_to_canonical(state, max_steps=10_000_000):
    """Step from ``state`` to the next canonical configuration (a, 0, 0, d, 0).

    Delegates to run_until with is_canonical as the stopping test. The start
    state is never returned itself, because run_until checks only after a
    step. Returns ``(state, steps)``; RuntimeError propagates from run_until.
    """
    return run_until(state=state, predicate=is_canonical, max_steps=max_steps)


def run_to_target(state, target, max_steps=10_000_000):
    """Step from ``state`` until the machine sits exactly on ``target``.

    Because run_until checks only after each step, ``target == state`` is
    matched only if the machine later returns to it. Returns
    ``(state, steps)``; RuntimeError propagates from run_until on halt or
    when the step budget runs out.
    """
    def reached(candidate):
        return candidate == target

    return run_until(state, reached, max_steps=max_steps)


def print_header(title):
    """Print a blank line, then ``title`` framed above and below by 70 '=' characters."""
    rule = "=" * 70
    print(f"\n{rule}\n  {title}\n{rule}")


def _check_reaches(label, start, expected):
    """Run the machine from ``start`` until it reaches ``expected``, and report.

    ``label`` is printed verbatim in front of the outcome (e.g. ``"k=1, a=2: "``).
    Prints a PASS line with the step count, or a FAIL line carrying the
    RuntimeError message from run_to_target (the machine halted or the step
    limit was hit). Returns True on PASS, False on FAIL.
    """
    try:
        result, steps = run_to_target(start, expected)
    except RuntimeError as e:
        print(f"  FAIL: {label}{e}")
        return False
    print(f"  PASS: {label}{start} -> {result} in {steps} steps")
    return True


def _invariant_holds(a, d):
    """Return P(a, d) = (d >= 2 and 2a >= d + 2) for a canonical state (a,0,0,d,0)."""
    return d >= 2 and 2 * a >= d + 2


def _check_preservation(cases):
    """Check that P carries over from each input pair to its output pair.

    ``cases`` yields ``(label, a_in, d_in, a_out, d_out)`` tuples. A case
    fails when P(a_in, d_in) holds but P(a_out, d_out) does not. Every
    failure is printed, and "All PASS" is printed when none fail.
    Returns True when every case preserves P.
    """
    ok = True
    for label, a_in, d_in, a_out, d_out in cases:
        if _invariant_holds(a_in, d_in) and not _invariant_holds(a_out, d_out):
            print(f"      FAIL: {label}: "
                  f"P({a_in},{d_in})=True but P({a_out},{d_out})=False")
            ok = False
    if ok:
        print("      All PASS")
    return ok


def main():
    """Run every verification, print a report, and return True if all passed.

    Verifications 1-6 simulate the machine and check each claimed transition
    over a grid of parameters. Verification 7 checks numerically, and by a
    printed argument, that P(a, d) = (d >= 2 and 2a >= d+2) holds initially,
    is preserved by those transitions, and implies their preconditions. A
    final bonus section prints the successive canonical states reached from
    (1,0,0,0,0). It is informational and does not affect the returned result.
    """
    all_pass = True
    P = _invariant_holds  # short alias matching the notation used in the report

    # =========================================================================
    # Verification 1: c0 = (1,0,0,0,0) reaches (3,0,0,2,0)
    # =========================================================================
    print_header("Verification 1: (1,0,0,0,0) ->+ (3,0,0,2,0)")

    c0 = (1, 0, 0, 0, 0)
    target = (3, 0, 0, 2, 0)
    try:
        state, steps = run_to_target(c0, target)
    except RuntimeError as e:
        print(f"  FAIL: {e}")
        all_pass = False
    else:
        # run_to_target returns only once the exact target is reached.
        print(f"  PASS: (1,0,0,0,0) -> {state} in {steps} steps")

    # Trace the path for insight. The bound keeps output finite if the target
    # is never hit.
    trace_limit = 50
    print("\n  Full trace from (1,0,0,0,0) to (3,0,0,2,0):")
    s = c0
    for i in range(trace_limit):
        ns = step(s)
        if ns is None:
            print(f"    Step {i}: {s} -> HALT")
            break
        canon_mark = " <-- canonical" if is_canonical(ns) else ""
        print(f"    Step {i+1}: {s} -> {ns}{canon_mark}")
        s = ns
        if s == target:
            print(f"    Reached target {s} at step {i+1}")
            break
    else:
        # Neither halted nor reached the target within the trace bound.
        print(f"    Target {target} not reached within {trace_limit} steps")

    # =========================================================================
    # Verification 2: Even d=2k, a >= k+1: (a,0,0,2k,0) ->+ (a+k+2, 0,0, 3k+5, 0)
    # =========================================================================
    print_header(
        "Verification 2: (a,0,0,2k,0) ->+ (a+k+2, 0, 0, 3k+5, 0)\n"
        "  for even d=2k, a >= k+1"
    )

    for k in range(1, 8):
        for a in (k + 1, k + 2, k + 5, 2 * k + 3):
            if not _check_reaches(
                f"k={k}, a={a}: ",
                (a, 0, 0, 2 * k, 0),
                (a + k + 2, 0, 0, 3 * k + 5, 0),
            ):
                all_pass = False

    # =========================================================================
    # Verification 3: d=3, a >= 1: (a,0,0,3,0) ->+ (a+1, 0, 0, 4, 0)
    # =========================================================================
    print_header("Verification 3: (a,0,0,3,0) ->+ (a+1, 0, 0, 4, 0) for a >= 1")

    for a in range(1, 15):
        if not _check_reaches(f"a={a}: ", (a, 0, 0, 3, 0), (a + 1, 0, 0, 4, 0)):
            all_pass = False

    # =========================================================================
    # Verification 4: d=5, a >= 1: (a,0,0,5,0) ->+ (a, 0, 0, 2, 0)
    # =========================================================================
    print_header("Verification 4: (a,0,0,5,0) ->+ (a, 0, 0, 2, 0) for a >= 1")

    for a in range(1, 15):
        if not _check_reaches(f"a={a}: ", (a, 0, 0, 5, 0), (a, 0, 0, 2, 0)):
            all_pass = False

    # =========================================================================
    # Verification 5: Odd d=2k+1, k>=3, a>=k:
    #   (a,0,0,2k+1,0) ->+ (a+k-2, 0, 0, 3k-4, 0)
    # =========================================================================
    print_header(
        "Verification 5: (a,0,0,2k+1,0) ->+ (a+k-2, 0, 0, 3k-4, 0)\n"
        "  for odd d=2k+1, k >= 3, a >= k"
    )

    for k in range(3, 12):
        for a in (k, k + 1, k + 3, 2 * k):
            if not _check_reaches(
                f"k={k}, a={a}: ",
                (a, 0, 0, 2 * k + 1, 0),
                (a + k - 2, 0, 0, 3 * k - 4, 0),
            ):
                all_pass = False

    # =========================================================================
    # Verification 6: (A, 0, 0, 2, E) with E >= 1: ->* (A+2E, 0, 0, 3E+2, 0)
    # =========================================================================
    print_header(
        "Verification 6: (A,0,0,2,E) ->* (A+2E, 0, 0, 3E+2, 0) for E >= 1"
    )

    for E in range(1, 10):
        for A in (0, 1, 2, E, E + 2):
            if not _check_reaches(
                f"A={A}, E={E}: ",
                (A, 0, 0, 2, E),
                (A + 2 * E, 0, 0, 3 * E + 2, 0),
            ):
                all_pass = False

    # =========================================================================
    # Verification 7: Predicate P(a,d) = (d >= 2 and 2a >= d+2) is preserved
    # =========================================================================
    print_header("Verification 7: Predicate P(a,d) = (d>=2 and 2a>=d+2) preserved")

    # --- 7a-7d: P is preserved by each transition formula ---

    print("\n  7a. Testing P preservation under Verification 2 transitions:")
    print("      (a,0,0,2k,0) -> (a+k+2, 0, 0, 3k+5, 0)")
    if not _check_preservation(
        (f"k={k}, a={a}", a, 2 * k, a + k + 2, 3 * k + 5)
        for k in range(1, 50)
        for a in range(k + 1, 4 * k)
    ):
        all_pass = False

    print("\n  7b. Testing P preservation under Verification 3 transitions:")
    print("      (a,0,0,3,0) -> (a+1, 0, 0, 4, 0)")
    if not _check_preservation(
        (f"a={a}", a, 3, a + 1, 4) for a in range(1, 100)
    ):
        all_pass = False

    print("\n  7c. Testing P preservation under Verification 4 transitions:")
    print("      (a,0,0,5,0) -> (a, 0, 0, 2, 0)")
    if not _check_preservation(
        (f"a={a}", a, 5, a, 2) for a in range(1, 100)
    ):
        all_pass = False

    print("\n  7d. Testing P preservation under Verification 5 transitions:")
    print("      (a,0,0,2k+1,0) -> (a+k-2, 0, 0, 3k-4, 0)")
    if not _check_preservation(
        (f"k={k}, a={a}", a, 2 * k + 1, a + k - 2, 3 * k - 4)
        for k in range(3, 50)
        for a in range(k, 4 * k)
    ):
        all_pass = False

    # --- 7e: P holds at the initial canonical point ---

    print("\n  7e. P holds at initial canonical point (3, 0, 0, 2, 0):")
    p_init = P(3, 2)
    status = "PASS" if p_init else "FAIL"
    if not p_init:
        all_pass = False
    print(f"      {status}: P(3, 2) = {p_init}  [2*3=6 >= 2+2=4]")

    # --- 7f: P implies preconditions for each formula ---

    print("\n  7f. P(a,d) implies the preconditions for each transition:")

    print("      Even d=2k: need a >= k+1")
    prec_pass = True
    for k in range(1, 100):
        for a in range(0, 3 * k):
            if P(a, 2 * k) and not (a >= k + 1):
                print(f"        FAIL: k={k}, a={a}: P holds but a < k+1")
                prec_pass = False
                all_pass = False
    if prec_pass:
        print("        PASS: P(a, 2k) => 2a >= 2k+2 => a >= k+1")

    print("      d=3: need a >= 1")
    # P(a,3) requires 2a >= 5, i.e. a >= 3
    prec_d3 = not P(0, 3) and not P(1, 3) and not P(2, 3) and P(3, 3)
    status = "PASS" if prec_d3 else "FAIL"
    if not prec_d3:
        all_pass = False
    print(f"        {status}: P(a,3) requires a >= 3, which implies a >= 1")

    print("      d=5: need a >= 1")
    # P(a,5) requires 2a >= 7, i.e. a >= 4
    prec_d5 = not P(3, 5) and P(4, 5)
    status = "PASS" if prec_d5 else "FAIL"
    if not prec_d5:
        all_pass = False
    print(f"        {status}: P(a,5) requires a >= 4, which implies a >= 1")

    print("      Odd d=2k+1, k >= 3: need a >= k")
    prec_pass2 = True
    for k in range(3, 100):
        for a in range(0, 3 * k):
            if P(a, 2 * k + 1) and not (a >= k):
                print(f"        FAIL: k={k}, a={a}: P holds but a < k")
                prec_pass2 = False
                all_pass = False
    if prec_pass2:
        print("        PASS: P(a, 2k+1) => 2a >= 2k+3 => a >= k+2 >= k")

    # --- 7g: Analytical proof ---

    print("\n  7g. Analytical argument that P is preserved:")
    print("      P(a, d) := d >= 2 and 2a >= d + 2")
    print()
    print("      V2 (even, d=2k, a>=k+1):")
    print("        d_out = 3k+5, a_out = a+k+2")
    print("        2*a_out - d_out - 2 = 2(a+k+2) - (3k+5) - 2 = 2a - k - 3")
    print("        P_in => 2a >= 2k+2 => 2a - k - 3 >= 2k+2 - k - 3 = k - 1 >= 0")
    print("        So P is preserved for k >= 1.")
    print()
    print("      V3 (d=3->4):")
    print("        2*(a+1) - 4 - 2 = 2a - 4")
    print("        P(a,3) => 2a >= 5 => 2a - 4 >= 1 >= 0. OK.")
    print()
    print("      V4 (d=5->2):")
    print("        2*a - 2 - 2 = 2a - 4")
    print("        P(a,5) => 2a >= 7 => 2a - 4 >= 3 >= 0. OK.")
    print()
    print("      V5 (odd, d=2k+1, k>=3, a>=k):")
    print("        d_out = 3k-4, a_out = a+k-2")
    print("        2*a_out - d_out - 2 = 2(a+k-2) - (3k-4) - 2 = 2a - k - 2")
    print("        P_in => 2a >= 2k+3 => 2a - k - 2 >= 2k+3 - k - 2 = k + 1 >= 4")
    print("        So P is preserved for k >= 3.")

    # =========================================================================
    # Summary
    # =========================================================================
    print()
    print("=" * 70)
    if all_pass:
        print("  ALL VERIFICATIONS PASSED")
    else:
        print("  SOME VERIFICATIONS FAILED")
    print("=" * 70)
    print()

    # =========================================================================
    # Bonus: Long-run trajectory from (1,0,0,0,0) (informational only)
    # =========================================================================
    print_header("Bonus: Long-run trajectory from (1,0,0,0,0)")

    s = (1, 0, 0, 0, 0)
    canonical_points = []
    total_steps = 0
    for _ in range(30):
        try:
            s_next, steps = run_to_canonical(s)
        except RuntimeError as e:
            # run_until raises both when the machine halts and when the step
            # budget runs out; report which one happened instead of guessing.
            print(f"  Stopped after {total_steps} total steps: {e}")
            break
        total_steps += steps
        canonical_points.append((s_next, total_steps))
        s = s_next

    print("  Successive canonical states (a, 0, 0, d, 0):")
    print(f"  {'State':<30s} {'Cum. Steps':>12s}  {'P(a,d)':>8s}  {'a/d ratio':>10s}")
    print(f"  {'-'*30} {'-'*12}  {'-'*8}  {'-'*10}")
    for st, cum_steps in canonical_points:
        a, _, _, d, _ = st
        p_holds = P(a, d)
        ratio = f"{a/d:.4f}" if d > 0 else "inf"
        print(f"  {str(st):<30s} {cum_steps:>12d}  {str(p_holds):>8s}  {ratio:>10s}")

    # Growth analysis: 10 points give 9 consecutive ratios.
    if len(canonical_points) >= 10:
        a_vals = [st[0] for st, _ in canonical_points]
        print("\n  Growth ratios (a[i+1]/a[i]) for last 10 points:")
        for i in range(len(a_vals) - 10, len(a_vals) - 1):
            if a_vals[i] > 0:
                print(f"    a[{i+1}]/a[{i}] = {a_vals[i+1]}/{a_vals[i]}"
                      f" = {a_vals[i+1]/a_vals[i]:.4f}")
        print("\n  The growth ratio converges to ~3/2 = 1.5,")
        print("  consistent with unbounded growth (non-halting).")

    return all_pass


if __name__ == "__main__":
    # Script entry point: run all verifications and print the report to stdout.
    main()
