#!/usr/bin/python3
# SPDX-FileCopyrightText: 2023 Justin Blanchard <UncombedCoconut@gmail.com>
# SPDX-License-Identifier: Apache-2.0 OR MIT

from argparse import ArgumentParser
from bit_set import BitSet
from glob import glob
from matplotlib import pyplot as plt
from matplotlib.pyplot import annotate
from supervenn import supervenn
from tqdm import tqdm
from unittest.mock import patch

# Command-line interface. argparse's automatic help is switched off so that the
# short option "-h" is free to mean --height; --help is registered by hand last.
ap = ArgumentParser(add_help=False)
ap.add_argument(
    '-u', '--universe',
    help='Optional BitSet file to restrict to',
)
ap.add_argument(
    '-p', '--preserve-order',
    action='store_true',
    help='Keep the rows in the same order as the BitSets on the command line',
)
ap.add_argument(
    '-w', '--width',
    type=int,
    default=12,
    help='Width, in the coordinate system used by Matplotlib',
)
ap.add_argument(
    '-h', '--height',
    type=int,
    default=9,
    help='Height, in the coordinate system used by Matplotlib',
)
ap.add_argument(
    '-o', '--output',
    default='venn.png',
    help='Output PNG file',
)
# With no positional arguments, every *.bs file found in the current directory
# (globbed once, while the parser is being built) is plotted.
ap.add_argument(
    'sets',
    nargs='*',
    default=glob('*.bs'),
    help='BitSet files describing the sets to plot',
)
ap.add_argument('--help', action='help', help='show this help message and exit')
args = ap.parse_args()

# Value later handed to supervenn as `sets_ordering`.
if args.preserve_order:
    sets_ordering = None
else:
    sets_ordering = 'minimize gaps'

# Every set is intersected with this before plotting.
if args.universe:
    universe = BitSet.open(args.universe)
else:
    universe = BitSet.universe()

class Lengthy:
    """Placeholder object that only knows its own size.

    It answers ``len()`` (and hence truthiness) with the number it was built
    from, without holding any elements.
    """

    def __init__(self, l):
        # The size reported by __len__.
        self._l = l

    def __len__(self):
        return self._l

def break_into_chunks(sets):
    """Partition the global ``universe`` by membership in each of ``sets``.

    This is the function patched in over supervenn's own ``break_into_chunks``.
    It starts the recursion in ``_break_into_chunks`` with an empty result
    dict, the whole ``universe`` and an empty membership set, and returns
    whatever that call returns.

    A tqdm progress bar sized to ``2**len(sets)`` (one tick per possible
    membership combination) is shown while the recursion runs.
    """
    combinations = 2 ** len(sets)
    with tqdm(total=combinations) as progress:
        result = _break_into_chunks(sets, {}, universe, 0, frozenset(), progress)
    return result

def _break_into_chunks(sets, chunks, universe_slice, level, occurs_in, progress):
    """Recursively split ``universe_slice`` by membership in ``sets[level:]``.

    At each level the slice is divided into the part outside ``sets[level]``
    and the part inside it; the index ``level`` is added to ``occurs_in`` for
    the inside part. When all sets have been consumed, a non-empty slice is
    recorded in ``chunks`` under ``frozenset(occurs_in)`` as a ``Lengthy``
    carrying only the slice's size.

    Args:
        sets: Sequence of set-like objects supporting ``-`` and ``&`` against
            ``universe_slice``.
        chunks: Dict that is filled in place and also returned.
        universe_slice: The elements still under consideration; must support
            truthiness and ``len()``.
        level: Index of the next set to split on.
        occurs_in: Indices of the sets that contain the current slice; needs a
            ``union`` method (the top-level caller passes a ``frozenset``).
        progress: Object with an ``update([n])`` method, advanced by one per
            leaf reached and by the size of each pruned subtree.

    Returns:
        ``chunks``, on every path, including the empty-slice early exit, so a
        caller starting from an empty universe gets an empty dict, not None.
    """
    if not universe_slice:
        # Nothing can land below this node: skip the subtree but credit the
        # progress bar with all of the leaves it would have visited.
        progress.update(2**(len(sets)-level))
        return chunks
    if level >= len(sets):
        progress.update()
        chunks[frozenset(occurs_in)] = Lengthy(len(universe_slice))
        return chunks
    current = sets[level]
    _break_into_chunks(sets, chunks, universe_slice - current, level+1, occurs_in, progress)
    _break_into_chunks(sets, chunks, universe_slice & current, level+1, occurs_in.union((level,)), progress)
    return chunks

def override_annotate(label, **kwargs):
    """Forward to Matplotlib's ``annotate``, adjusting labels rotated by 90.

    When ``rotation`` is exactly 90, the annotation's y coordinate is raised by
    ``1/args.height`` and its vertical alignment is forced to ``'top'``. Every
    other call is passed through unchanged.

    Args:
        label: Annotation text, passed positionally to ``annotate``.
        **kwargs: Keyword arguments for ``annotate``; ``xy`` must be present
            and indexable whenever ``rotation`` is 90.

    Returns:
        Whatever the original ``annotate`` returns.
    """
    if kwargs.get('rotation') == 90:
        xy = kwargs['xy']
        kwargs['xy'] = (xy[0], xy[1] + 1/args.height)
        kwargs['va'] = 'top'
    return annotate(label, **kwargs)

# Install the replacements defined above. The patchers are started and never
# stopped, so they remain in effect for the rest of the script.
# break_into_chunks is replaced under both supervenn module names that expose it.
patch(
    'supervenn._algorithms.break_into_chunks',
    new=break_into_chunks,
).start()
patch(
    'supervenn._plots.break_into_chunks',
    new=break_into_chunks,
).start()
patch(
    'matplotlib.pyplot.annotate',
    new=override_annotate,
).start()

# NOTE: repopulate using conversion_hacks.py
# Load each BitSet restricted to the universe, keyed by its row label. The
# label is the file name with a trailing '.bs' removed; names without that
# suffix are kept whole instead of losing their last three characters.
id_sets = {
    (name[:-3] if name.endswith('.bs') else name): BitSet.open(name) & universe
    for name in args.sets
}
# Sets with nothing left inside the universe are left out of the plot.
id_sets = {k: v for (k, v) in id_sets.items() if v}

plt.figure(figsize=(args.width, args.height))
supervenn(
    list(id_sets.values()),
    list(id_sets),
    sets_ordering=sets_ordering,
    widths_minmax_ratio=0.1,
    rotate_col_annotations=True,
    col_annotations_area_height=2,
    fontsize=9,
)
plt.xlabel(f'Beavers / {len(universe)}')
plt.ylabel('Deciders')
plt.savefig(args.output)
plt.close()
print('Saved', args.output)
