#!/usr/bin/python3
# SPDX-FileCopyrightText: 2023 Justin Blanchard <UncombedCoconut@gmail.com>
# SPDX-License-Identifier: Apache-2.0 OR MIT

from bit_set import BitSet
from zstandard import ZstdDecompressor
import io, mmap, os

# Machine database file. The readers in this file skip a 30-byte header and
# then address the data as fixed-size records of 6 bytes per state.
DB_PATH = '../all_5_states_undecided_machines_with_global_header'

def seed_to_desc(seed, n_states=5):
    """Return the text description of record number `seed` in the database.

    The record is read from DB_PATH at byte offset 30 + 6*n_states*seed
    (a 30-byte header followed by 6 bytes per state) and rendered with
    tm_to_desc.
    """
    record_size = 6 * n_states
    with open(DB_PATH, 'rb') as db:
        db.seek(30 + record_size * seed)
        return tm_to_desc(db.read(record_size))

def tm_to_desc(tm):
    """Render a binary transition table as a text description.

    `tm` is a sequence of small ints (e.g. bytes) holding 6 values per state:
    one (write, direction, target) triple for each of the two read symbols.
    A triple whose target is 0 is rendered as '---'; otherwise it becomes the
    write digit, 'R' (direction 0) or 'L' (direction 1), and the target state
    letter ('A' for 1, 'B' for 2, ...). States are separated by '_'.
    """
    rows = []
    for state in range(len(tm) // 6):
        cells = []
        for symbol in range(2):
            start = 6 * state + 3 * symbol
            write, move, target = tm[start:start + 3]
            if target == 0:
                # Undefined transition: direction and write value are ignored.
                cells.append('---')
            else:
                letter = chr(ord('A') + target - 1)
                cells.append(f'{write}{"RL"[move]}{letter}')
        rows.append(''.join(cells))
    return '_'.join(rows)

def desc_to_tm(desc):
    """Parse a text description into a binary transition table.

    `desc` consists of 6-character state rows joined by '_'; each row holds
    two 3-character transitions (write digit, 'L'/'R', target letter).
    Returns a bytearray with 6 bytes per state: (write, direction, target)
    for each transition, where 'R' is 0, 'L' is 1 and 'A' is target 1.
    A transition whose target character is 'Z' or '-' is left as three
    zero bytes.
    """
    n_states = (len(desc) + 1) // 7
    table = bytearray(6 * n_states)
    direction_codes = {'L': 1, 'R': 0}
    for state in range(n_states):
        row_start = 7 * state
        if state > 0:
            assert desc[row_start - 1] == '_'
        for symbol in range(2):
            src = row_start + 3 * symbol
            write_ch, move_ch, target_ch = desc[src:src + 3]
            if target_ch in 'Z-':
                continue  # halting/undefined transition stays zeroed
            write = ord(write_ch) - ord('0')
            move = direction_codes[move_ch]
            target = ord(target_ch) - ord('A') + 1
            assert 0 <= write <= 1 and 0 <= move <= 1 and 1 <= target <= n_states
            dst = 6 * state + 3 * symbol
            table[dst:dst + 3] = (write, move, target)
    return table

def convert_hs():
    """Convert halting_segment_1..7.index files into cumulative BitSet files.

    Each index file is read as a stream of 4-byte big-endian seeds. The seeds
    accumulate across files, so HaltingSegment-{n}.bs contains every seed seen
    in files 1 through n. The running total is printed after each file.
    """
    seen = BitSet()
    for n in range(1, 8):
        with open(f'halting_segment_{n}.index', 'rb') as index_file:
            # Read 4 bytes at a time until read() returns b'' at end of file.
            for raw_seed in iter(lambda: index_file.read(4), b''):
                seen.add(int.from_bytes(raw_seed, 'big'))
        print(n, '->', len(seen))
        seen.write(f'HaltingSegment-{n}.bs')

def convert_far(name, iter_stop, dvf_path):
    """Convert a .dvf file into cumulative per-level BitSet files.

    The file is parsed as follows: the first 4 bytes are skipped, then each
    record consists of a 4-byte big-endian seed, 4 bytes that are not used
    here, a 4-byte big-endian info length, and that many bytes of info.
    A record is assigned level n = (info_length - 1) // 2; records whose level
    falls outside 0 <= n < iter_stop are ignored.

    For each n in 1..iter_stop-1, the set of seeds with level <= n is written
    to '{name}-{n}.bs' (and its size printed) if it is non-empty.

    Parameters:
        name: prefix for output file names and progress messages.
        iter_stop: exclusive upper bound on the levels that are kept.
        dvf_path: path of the input file.
    """
    sets = [BitSet() for _ in range(iter_stop)]
    # Use a context manager so the file handle is closed promptly rather than
    # whenever the anonymous file object happens to be collected.
    with open(dvf_path, 'rb') as dvf:
        data = dvf.read()
    pos = 4
    while pos < len(data):
        info_length = int.from_bytes(data[pos+8:pos+12], 'big')
        n = (info_length-1)//2
        # The lower bound matters: info_length == 0 gives n == -1, which would
        # otherwise index sets[-1] and wrongly credit the seed to the last set.
        if 0 <= n < iter_stop:
            sets[n].add(int.from_bytes(data[pos:pos+4], 'big'))
        pos += 12 + info_length
    del data  # release the raw file contents before building cumulative sets
    assert not sets[0], 'WTF is iteration zero?'
    for n in range(1, iter_stop):
        # Make each set cumulative: everything at level n-1 also counts at n.
        sets[n] |= sets[n-1]
        if sets[n]:
            print(f'{name}-{n}:', len(sets[n]))
            sets[n].write(f'{name}-{n}.bs')

def convert_cps(iter_stop):
    """Convert zstd-compressed CPS result lists into cumulative BitSet files.

    For each n in 1..iter_stop-1, reads 'src/CPS-LCR_-{n}.txt.zst', takes the
    text after the last comma on every line as a machine description, looks
    up its record number in the database at DB_PATH, and adds it to a running
    set. The set is written to 'CPS-LCR_-{n}.bs' when non-empty. If the output
    file already exists it is loaded with `solved.read` instead of being
    recomputed.

    Raises:
        LookupError: if a machine description is not present in the database.
    """
    name = 'CPS-LCR_'
    dec = ZstdDecompressor()

    with open(DB_PATH, 'rb') as db_f:
        with mmap.mmap(db_f.fileno(), 0, prot=mmap.PROT_READ) as db:
            # Header fields, read as big-endian ints from bytes 0-3 and 8-11.
            undecided_time, undecided_total = int.from_bytes(db[:4], 'big'), int.from_bytes(db[8:12], 'big')
            # The database is searched as two record ranges; the binary search
            # below assumes each range is sorted by raw record bytes
            # (NOTE(review): confirm against the database format).
            slices = [(0, undecided_time), (undecided_time, undecided_total)]

            def desc_to_id(desc):
                """Return the record number of the machine described by `desc`."""
                tm = desc_to_tm(desc)
                size = len(tm)
                for (lo, hi) in slices:
                    # Binary search for `tm` within records [lo, hi).
                    while lo < hi:
                        mid = (lo + hi) // 2
                        pos = 30 + size * mid  # 30-byte header precedes the records
                        record = db[pos:pos+size]
                        if record < tm: lo = mid + 1
                        elif record == tm: return mid
                        else: hi = mid
                # Previously this fell through and returned None, which was
                # then handed to solved.add(); fail loudly instead.
                raise LookupError(f'machine {desc!r} not found in {DB_PATH}')

            solved = BitSet()
            for n in range(1, iter_stop):
                path_in = f'src/{name}-{n}.txt.zst'
                path_out = f'{name}-{n}.bs'
                if os.path.exists(path_out):
                    # Reuse a previously written result rather than recomputing it.
                    solved.read(path_out)
                    print('Trust', path_out, '->', len(solved))
                    continue
                with open(path_in, 'rb') as f:
                    text_stream = io.TextIOWrapper(dec.stream_reader(f), encoding='ascii')
                    for line in text_stream:
                        # The machine description is the last comma-separated field.
                        solved.add(desc_to_id(line.rpartition(',')[-1].strip()))

                if solved:
                    print(f'{name}-{n}:', len(solved))
                    solved.write(path_out)

def convert_bouncers():
    """Write Bouncers.bs: the starting set minus every seed listed in the .umf file.

    The .umf file is read as a stream of 4-byte big-endian seeds, each of
    which is discarded from the starting BitSet. The size of the remaining
    set is printed before it is written out.
    """
    # Alternative starting set: BitSet.universe()
    start_path = '/home/justinb/scratch/beaver/deciders-unrestricted/src/april-undecided.bs'
    # Alternative input: '/home/justinb/scratch/beaver/deciders-unrestricted/src/TonyG/Bouncers/Bouncers.umf'
    umf_path = '/home/justinb/scratch/beaver/deciders-unrestricted/src/TonyG/Bouncers/Bouncers_april.umf'

    remaining = BitSet.open(start_path)
    with open(umf_path, 'rb') as umf:
        # Read 4 bytes at a time until read() returns b'' at end of file.
        for raw_seed in iter(lambda: umf.read(4), b''):
            remaining.discard(int.from_bytes(raw_seed, 'big'))
    print('Bouncers ->', len(remaining))
    remaining.write('Bouncers.bs')


#convert_far('FAR-Direct', 9, '/home/justinb/scratch/beaver/bbchallenge-deciders/decider-finite-automata-reduction/output/finite_automata_reduction.dvf')
#convert_far('FAR-MitM', 12, '/home/justinb/scratch/beaver/deciders-unrestricted/FAR-mitm/output/finite_automata_reduction.dvf')
#convert_cps(8)
#convert_bouncers()
