#!/usr/bin/python3
# SPDX-FileCopyrightText: 2023 Justin Blanchard <UncombedCoconut@gmail.com>
# SPDX-License-Identifier: Apache-2.0 OR MIT


BYTES = 11083008

class BitSet:
    __hash__ = None

    def __init__(self, iterable=()):
        if isinstance(iterable, BitSet):
            self.data = iterable.data.copy()
        else:
            self.data = bytearray(BYTES)
            for i in iterable: self.add(i)
        self.mv = memoryview(self.data).cast('Q')

    @classmethod
    def universe(cls):
        bs = cls()
        u = bs.mv
        for i in range(len(u)):
            u[i] = 0xffffffffffffffff
        return bs

    @classmethod
    def open(cls, path):
        bs = cls()
        bs.read(path)
        return bs

    def read(self, path):
        with open(path, 'rb') as f:
            f.readinto(self.data)

    def write(self, path):
        with open(path, 'wb') as f:
            f.write(self.data)

    def add(self, i):
        self.data[i//8] |= (1 << (i%8))

    def discard(self, i):
        self.data[i//8] &= ~(1 << (i%8))

    def pop(self):
        i = next(self, None)
        if i is None: raise KeyError('pop from an empty set')
        self.discard(i)
        return i

    def remove(self, i):
        if i not in self: raise KeyError(i)
        self.discard(i)

    def __contains__(self, i):
        return bool(self.data[i//8] & (1 << (i%8)))

    def __len__(self):
        return sum(mask.bit_count() for mask in self.mv)

    def __bool__(self):
        return any(self.mv)

    def copy(self):
        return self.__class__(self)

    def union(self, rhs):
        bs = self.copy()
        bs.update(rhs)
        return bs

    def update(self, rhs):
        if isinstance(rhs, BitSet):
            me = self.mv
            for i, v in enumerate(rhs.mv):
                me[i] |= v
        else:
            for i in rhs:
                self.add(i)

    def difference(self, rhs):
        bs = self.copy()
        bs.difference_update(rhs)
        return bs

    def difference_update(self, rhs):
        if isinstance(rhs, BitSet):
            me = self.mv
            for i, v in enumerate(rhs.mv):
                me[i] &= ~v
        else:
            for i in rhs:
                self.discard(i)

    def intersection(self, rhs):
        bs = self.copy()
        bs.intersection_update(rhs)
        return bs

    def intersection_update(self, rhs):
        if isinstance(rhs, BitSet):
            me = self.mv
            for i, v in enumerate(rhs.mv):
                me[i] &= v
        else:
            for i in self:
                if i not in rhs:
                    self.discard(i)

    def symmetric_difference(self, rhs):
        bs = self.copy()
        bs.symmetric_difference_update(rhs)
        return bs

    def symmetric_difference_update(self, rhs):
        if isinstance(rhs, BitSet):
            me = self.mv
            for i, v in enumerate(rhs.mv):
                me[i] ^= v
        else:
            for i in rhs:
                self.data[i//8] ^= (1 << (i%8))

    def isdisjoint(self, rhs):
        if isinstance(rhs, BitSet):
            return not any(x & y for (x, y) in zip(self.mv, rhs.mv))
        else:
            return not any(i in rhs for i in self)

    def issubset(self, rhs):
        if isinstance(rhs, BitSet):
            return all(x & y == x for (x, y) in zip(self.mv, rhs.mv))
        else:
            return all(i in rhs for i in self)

    def issuperset(self, rhs):
        if isinstance(rhs, BitSet):
            return all(x & y == y for (x, y) in zip(self.mv, rhs.mv))
        else:
            return all(i in self for i in rhs)

    def __eq__(self, rhs):
        if isinstance(rhs, BitSet):
            return self.mv == rhs.mv
        else:
            return False

    def __gt__(self, rhs):
        return self.issuperset(rhs) and len(self) > len(rhs)

    def __lt__(self, rhs):
        return self.issubset(rhs) and len(self) > len(rhs)

    def __ne__(self, rhs):
        if isinstance(rhs, BitSet):
            self.mv != rhs.mv
        else:
            return True

    def __iter__(self):
        for i, v in enumerate(self.mv):
            base = 64 * i - 1
            while v:
                lo_bit = v & -v
                v ^= lo_bit
                yield base + lo_bit.bit_length()

    def incr(f):
        def i_f(self, rhs):
            f(self, rhs)
            return self
        return i_f

    clear = __init__
    __ge__ = issuperset
    __le__ = issubset
    __and__ = intersection
    __rand__ = intersection
    __iand__ = incr(intersection_update)
    __or__ = union
    __ror__ = union
    __ior__ = incr(update)
    __isub__ = incr(difference_update)
    __xor__ = symmetric_difference
    __ixor__ = incr(symmetric_difference_update)
    __rxor__ = symmetric_difference_update
    __add__ = union
    __sub__ = difference


if __name__ == '__main__':
    # Build a BitSet from a file of 4-byte big-endian integers and save it.
    import os
    import sys
    from resource import setrlimit, RLIMIT_AS
    setrlimit(RLIMIT_AS, (2**33, 2**34))
    # Validate explicitly: `assert` statements are stripped under `python -O`.
    if len(sys.argv) != 3:
        sys.exit(f'usage: {sys.argv[0]} INPUT OUTPUT')
    _, in_path, out_path = sys.argv
    if not os.path.exists(in_path):
        sys.exit(f'{in_path}: input file does not exist')
    if os.path.exists(out_path):
        sys.exit(f'{out_path}: refusing to overwrite existing file')
    bs = BitSet()
    add = bs.add
    with open(in_path, 'rb') as f:
        # Read many 4-byte records per call instead of one read(4) per record.
        # The chunk size is a multiple of 4, so records never straddle chunks;
        # a short trailing record is still decoded exactly as before.
        while (chunk := f.read(4 << 16)):
            for off in range(0, len(chunk), 4):
                add(int.from_bytes(chunk[off:off + 4], 'big'))
    bs.write(out_path)
