#!/usr/bin/env python3
#
# Sniffles2
# A fast structural variant caller for long-read sequencing data
#
# Created:     28.05.2021
# Author:      Moritz Smolka
# Maintainer:  Hermann Romanek
# Contact:     sniffles@romanek.at
#
import logging
import logging.config
import multiprocessing
from collections import deque
from typing import Optional

from sniffles.utils.resmon import ResourceMonitor

import sys

if not sys.version_info >= (3, 10):
    print(f"Error: Sniffles2 must be run with Python version 3.10 or above (detected Python version: {sys.version_info.major}.{sys.version_info.minor}). Exiting")
    exit(1)

import math
import time
import os

import pysam

from sniffles.config import SnifflesConfig
from sniffles import vcf
from sniffles import snf
from sniffles import parallel
from sniffles import util

# TODO: Dev/Debugging only - Remove for prod
# When enabled, dbg_get_total_memory_usage_MB() is defined below; it is only
# defined (and must only be called) while this flag is True.
DEV_MONITOR_MEM = False

if DEV_MONITOR_MEM:
    try:
        import psutil
    except ImportError:
        logging.getLogger('sniffles.memory').warning('psutil not available.')

        def dbg_get_total_memory_usage_MB():
            """Fallback used when psutil is missing: the usage is unknown.

            Returns NaN instead of None so that callers which format the value
            as a float (``f"{value:.2f}"``) do not fail with a TypeError.
            """
            return float('nan')
    else:
        logging.getLogger('sniffles.memory').info('Watching memory')

        def dbg_get_total_memory_usage_MB():
            """Return the resident set size of this process and all of its
            (recursive) child processes, in MB (1 MB = 10^6 bytes)."""
            proc = psutil.Process(os.getpid())
            total = proc.memory_info().rss
            for child in proc.children(recursive=True):
                try:
                    total += child.memory_info().rss
                except psutil.NoSuchProcess:
                    # The child exited between being listed and being queried.
                    continue
            return total / (1000.0 * 1000.0)

# END:TODO


def Sniffles2_Main(processes: list[parallel.SnifflesWorker]):
    """Run one Sniffles2 invocation from configuration to final output.

    The function
      1. builds the run configuration and derives the run mode from the input
         file extensions (``call_sample``, ``genotype_vcf`` or ``combine``),
      2. opens the input (.bam/.cram or .snf/.tsv) and the requested outputs
         (VCF, SNF, population SNF),
      3. splits the work into tasks,
      4. runs the tasks on the workers and collects the finished tasks,
      5. emits the results in task id order and closes/indexes the outputs.

    Args:
        processes: List to which every created worker is appended. The caller
            keeps a reference to it so that the workers can still be shut down
            if this function exits early.

    Invalid input or configuration is reported through
    ``util.fatal_error_main``. That helper is not visible here; the code below
    assumes it aborts the run and does not return.
    """
    #
    # Determine Sniffles2 run mode
    #
    config = SnifflesConfig()

    if config.dev_debug:
        try:
            import pydevd_pycharm
        except ImportError:
            logging.getLogger('sniffles.dev').info('pydevd not available!')
        else:
            try:
                pydevd_pycharm.settrace('localhost', port=config.dev_debug, stdoutToServer=True, stderrToServer=True)
            except:  # noqa - deliberately broad: a failed debugger attach must never abort the run
                logging.getLogger('sniffles.dev').info('Failed to connect to PyCharm debugger.', exc_info=True)
            else:
                logging.getLogger('sniffles.dev').info('Connected to PyCharm debugger.')

    # Lower-cased text after the last "." of every input path
    input_ext = [f.split(".")[-1].lower() for f in config.input]

    # needed for running on osx
    if sys.platform == "darwin":
        multiprocessing.set_start_method("fork")

    if len(set(input_ext)) > 1:
        util.fatal_error_main(f"Please specify either: A single .bam/.cram file - OR - one or more .snf files - OR - a single .tsv file containing a list of .snf files and optional sample ids as input. (supplied were: {list(set(input_ext))})")

    if "bam" in input_ext or "cram" in input_ext:
        # Count both alignment formats so that the message is also correct for CRAM input
        alignment_file_count = input_ext.count("bam") + input_ext.count("cram")
        if alignment_file_count > 1:
            util.fatal_error_main(f"Please specify max 1 .bam/.cram file as input (got {alignment_file_count})")
        config.input = config.input[0]

        if config.genotype_vcf is not None:
            config.mode = "genotype_vcf"
        else:
            config.mode = "call_sample"

        config.input_is_cram = False
        if "bam" in input_ext:
            config.input_mode = r"rb"
        elif "cram" in input_ext:
            config.input_mode = r"rc"
            config.input_is_cram = True

    elif "snf" in input_ext or "tsv" in input_ext:
        config.mode = "combine"
    else:
        util.fatal_error_main(f"Failed to determine run mode from input. Please specify either: A single .bam file - OR - one or more .snf files - OR - a single .tsv file containing a list of .snf files and optional sample ids as input. (supplied were: {list(set(input_ext))})")

    if config.mode != "call_sample" and config.snf is not None:
        util.fatal_error_main(f"--snf cannot be used with run mode {config.mode}")

    if config.vcf is None and config.snf is None:
        util.fatal_error_main("Please specify at least one of: --vcf or --snf for output (both may be used at the same time)")

    log = logging.getLogger('sniffles.main')
    if config.dev_debug_log:
        logging.getLogger().setLevel(logging.DEBUG)
    if config.dev_progress_log:
        logging.getLogger('sniffles.progress').setLevel(logging.INFO)

    # Sample columns of the output VCF as (internal id, sample name) pairs
    if config.mode == "call_sample":
        if config.sample_id is None:
            # config.sample_id,_=os.path.splitext(os.path.basename(config.input))
            config.sample_ids_vcf = [(0, "SAMPLE")]
        else:
            config.sample_ids_vcf = [(0, config.sample_id)]

    elif config.mode == "combine":
        config.sample_id = None

        if config.combine_consensus:
            config.sample_ids_vcf = [(0, "CONSENSUS")]
        else:
            config.sample_ids_vcf = []  # Determined from .snf headers

    log.info(f"Running {config.version}, build {config.build}")
    log.info(f"  Run Mode: {config.mode}")
    log.info(f"  Start on: {config.start_date}")
    log.info(f"  Working dir: {config.workdir}")
    log.info(f"  Used command: {config.command}")
    log.info("==============================")

    rkwargs = {}  # result kwargs, forwarded to every task result's emit()

    monitor = ResourceMonitor(config)
    if monitor and monitor.filename is not None:
        logging.getLogger('sniffles.resources').info(f'Logging memory usage to {monitor.filename}')

    #
    # call_sample/genotype_vcf: Open .bam file for single calling / .snf creation
    #
    contig_tandem_repeats = {}
    if config.mode == "call_sample" or config.mode == "genotype_vcf":
        log.info(f"Opening for reading: {config.input}")
        bam_in = pysam.AlignmentFile(config.input, config.input_mode)
        try:
            has_index = bam_in.check_index()
            if not has_index:
                # Funnel "no index" into the same handler as an index that fails to load
                raise ValueError
        except ValueError:
            util.fatal_error_main(f"Unable to load index for input file '{config.input}'. Please verify that your input file is sorted + indexed and that the index .bai file is valid and in the right location.")

        #
        # Load tandem repeat annotations
        #
        if config.tandem_repeats is not None:
            contig_tandem_repeats = util.load_tandem_repeats(config.tandem_repeats, config.tandem_repeat_region_pad)
            log.info(f"Opening for reading: {config.tandem_repeats} (tandem repeat annotations for {len(contig_tandem_repeats)} contigs)")

    #
    # genotype_vcf: Read SVs from VCF to be genotyped
    #
    if config.mode == "genotype_vcf":
        path, ext = os.path.splitext(config.genotype_vcf)
        ext = ext.lower()
        if ext == ".gz":
            vcf_in_handle = pysam.BGZFile(config.genotype_vcf, "rb")
        elif ext == ".vcf":
            vcf_in_handle = open(config.genotype_vcf, "r")
        else:
            util.fatal_error_main("Expected a .vcf or .vcf.gz file for genotyping using --genotype-vcf")
        vcf_in = vcf.VCF(config, vcf_in_handle)

        genotype_lineindex_order = []  # raw VCF line indices in input order
        genotype_lineindex_svs = {}    # raw VCF line index -> SV call
        genotype_contig_svs = {}       # contig name -> SV calls on that contig
        for svcall in vcf_in.read_svs_iter():
            if svcall.contig not in genotype_contig_svs:
                genotype_contig_svs[svcall.contig] = []
            assert (svcall.raw_vcf_line_index not in genotype_lineindex_svs)
            genotype_lineindex_order.append(svcall.raw_vcf_line_index)
            genotype_lineindex_svs[svcall.raw_vcf_line_index] = svcall
            genotype_contig_svs[svcall.contig].append(svcall)
        rkwargs['genotype_lineindex_order'] = genotype_lineindex_order
        log.info(f"Opening for reading: {config.genotype_vcf} (read {len(genotype_lineindex_svs)} SVs to be genotyped)")

    #
    # Open output files
    #
    vcf_out = None
    if config.vcf is not None:

        # Human-readable properties of the VCF output, used in log messages only
        vcf_output_info = []
        if config.mode == "combine":
            vcf_output_info.append("multi-sample")
        else:
            vcf_output_info.append("single-sample")
        if config.sort:
            vcf_output_info.append("sorted")
        if config.vcf_output_bgz:
            vcf_output_info.append("bgzipped")
            vcf_output_info.append("tabix-indexed")

        if len(vcf_output_info) == 0:
            vcf_output_info_str = ""
        else:
            vcf_output_info_str = f"({', '.join(vcf_output_info)})"

        if os.path.exists(config.vcf) and not config.allow_overwrite:
            util.fatal_error_main(f"Output file '{config.vcf}' already exists! Use --allow-overwrite to ignore this check and overwrite.")

        if config.vcf_output_bgz:
            if not config.sort:
                util.fatal_error_main(".gz (bgzip) output is only supported with sorting enabled")
            vcf_handle = pysam.BGZFile(config.vcf, "w")
        else:
            vcf_handle = open(config.vcf, "w")

        vcf_out = vcf.VCF(config, vcf_handle)

        if config.mode == "call_sample" or config.mode == "combine":
            if config.reference is not None:
                log.info(f"Opening for reading: {config.reference}")
            vcf_out.open_reference()

        log.info(f"Opening for writing: {config.vcf} {vcf_output_info_str}")

    # SNF writing during single sample mode
    snf_out = None
    if config.snf is not None:
        log.info(f"Opening for writing: {config.snf}")

        if os.path.exists(config.snf) and not config.allow_overwrite:
            util.fatal_error_main(f"Output file '{config.snf}' already exists! Use --allow-overwrite to ignore this check and overwrite.")
        else:
            snf_out = snf.SNFile(config, open(config.snf, "wb"))

    psnf_out = None
    if psnf_name := config.dev_population_snf:
        log.info(f'Creating population SNF: {psnf_name}')

        if os.path.exists(psnf_name) and not config.allow_overwrite:
            util.fatal_error_main(f'Population SNF {psnf_name} already exists!')
        from sniffles.snfp import PopulationSNF
        psnf_out = PopulationSNF(config, open(psnf_name, 'wb'))
        rkwargs['psnf_out'] = psnf_out

    #
    # Plan multiprocessing tasks
    #
    task_id = 0
    tasks = deque()
    contigs = []
    contig_tasks_intervals = {}  # contig name -> [(start, end, task), ...]

    if config.mode == "call_sample" or config.mode == "genotype_vcf":
        #
        # Process .bam header
        #
        task_classes = {
            'call_sample': parallel.CallTask,
            'genotype_vcf': parallel.GenotypeTask,
        }

        total_mapped = bam_in.mapped
        if (config.threads == 1 and not config.low_memory) or config.task_count_multiplier == 0:
            task_max_reads = total_mapped
        else:
            task_max_reads = max(1, math.floor(total_mapped / (config.threads * config.task_count_multiplier)))

        if total_mapped == 0:
            # Total mapped returns 0 for CRAM files
            config.task_read_id_offset_mult = 10 ** 9
        else:
            # BAM file
            # NOTE(review): math.log is the natural logarithm; log10 may have been
            # intended here - confirm. Left unchanged because it affects the value of
            # task_read_id_offset_mult; the natural log never yields a smaller
            # exponent than log10 would.
            config.task_read_id_offset_mult = 10 ** math.ceil(math.log(total_mapped) + 1)

        contig_lengths = []
        contigs_with_tr_annotations = 0
        for contig in bam_in.get_index_statistics():
            if task_max_reads == 0:
                task_count = 1
            else:
                task_count = max(1, math.ceil(contig.mapped / float(task_max_reads)))
            contig_str = str(contig.contig)

            if config.contig and contig_str not in config.contig:
                continue

            if config.regions_by_contig and contig_str not in config.regions_by_contig:
                continue

            contigs.append(contig_str)
            contig_length = bam_in.get_reference_length(contig_str)
            contig_lengths.append((contig_str, contig_length))
            # At least 1: a contig with fewer bases than planned tasks would otherwise
            # give a step of 0 and the loop below would never advance.
            task_length = max(1, math.floor(contig_length / float(task_count)))
            contigs_with_tr_annotations += int(contig_str in contig_tandem_repeats)
            startpos = 0

            while startpos < contig_length - 1:
                endpos = min(contig_length - 1, startpos + task_length)
                if config.genotype_vcf is not None:
                    if contig_str in genotype_contig_svs:
                        # SVs whose position lies in the half-open task interval [startpos, endpos)
                        genotype_svs = [target_sv for target_sv in genotype_contig_svs[contig_str] if target_sv.pos >= startpos and target_sv.pos < endpos]
                    else:
                        genotype_svs = []
                else:
                    genotype_svs = None

                task = task_classes[config.mode](
                    id=task_id,
                    contig=contig_str,
                    start=startpos,
                    end=endpos,
                    assigned_process_id=None,
                    tandem_repeats=contig_tandem_repeats[contig_str] if contig_str in contig_tandem_repeats else None,
                    genotype_svs=genotype_svs,
                    sv_id=0,
                    config=config,
                    regions=config.regions_by_contig.get(contig_str),
                )
                tasks.append(task)
                if contig_str not in contig_tasks_intervals:
                    contig_tasks_intervals[contig_str] = []
                contig_tasks_intervals[contig_str].append((task.start, task.end, task))
                startpos += task_length
                task_id += 1
        config.contig_lengths = contig_lengths

        if contigs_with_tr_annotations < len(contig_lengths) and config.tandem_repeats is not None:
            log.info(f"Info: {contigs_with_tr_annotations} of {len(contig_lengths)} contigs in the input sample have associated tandem repeat annotations.")

            if contigs_with_tr_annotations == 0:
                util.fatal_error_main("A tandem repeat annotations file was provided, but no matching annotations were found for any contig in the sample input file. Please check if the contig naming scheme in the tandem repeat annotations matches with the one in the input sample file.")

    elif config.mode == "combine":
        #
        # Process .snf headers
        #
        config.snf_input_info = []
        total_mapped = 0

        # List of filenames and optional sample label from tsv
        input_snfs_sample_ids: list[tuple[str, Optional[str]]] = []

        if len(config.input) == 1 and input_ext[0] == "tsv":
            log.info(f"Opening for reading: {config.input[0]} (loading list of .snf files and associated sample ids)")
            with open(config.input[0], "r") as tsv_handle:
                for line_index, line in enumerate(tsv_handle):
                    line_strip = line.strip()
                    # Skip blank lines and "#" comment lines
                    if len(line_strip) == 0 or line_strip[0] == "#":
                        continue
                    parts = line_strip.split("\t")
                    if len(parts) == 1:
                        snf_filename = parts[0]
                        sample_id = None
                    elif len(parts) == 2:
                        snf_filename = parts[0]
                        sample_id = parts[1]
                    else:
                        util.fatal_error_main(f"Invalid sample list .tsv : {config.input[0]} : Line {line_index + 1} - expected either one or two columns (first column: .snf filename, second column: optional sample id to overrule the one specified in the .snf file)")
                    input_snfs_sample_ids.append((snf_filename, sample_id))
        elif input_ext[0] == "snf":
            input_snfs_sample_ids = [(item, None) for item in config.input]
        else:
            util.fatal_error_main("Failed to determine .snf files to be combined. Please specify either one or more .snf files OR a single .tsv file as input for multi-calling.")

        if not input_snfs_sample_ids:
            # Without this check an empty sample list fails further down with a
            # NameError, because contig_lengths is only assigned while reading headers.
            util.fatal_error_main(f"No .snf files to be combined were found in: {config.input[0]}")

        log.info("The following samples will be processed in multi-calling:")
        for snf_internal_id, (input_filename, sample_id) in enumerate(input_snfs_sample_ids):
            snf_in = snf.SNFile(config, open(input_filename, "rb"), filename=input_filename)
            snf_in.read_header()
            total_mapped += snf_in.header["snf_candidate_count"]
            # Overwritten on every iteration: the contig list of the last file is the one used below
            contig_lengths = snf_in.header["config"]["contig_lengths"]
            if not config.dev_skip_snf_validation:
                if config.snf_block_size != snf_in.header["config"]["snf_block_size"]:
                    util.fatal_error_main(f"SNF block size differs for {input_filename}")
                if config.snf_format_version != snf_in.header["config"]["snf_format_version"]:
                    util.fatal_error_main(f"SNF format version for {input_filename} is not supported")
            # Sample id precedence: .tsv column, then .snf header, then file name without extension
            if sample_id is None:
                if snf_in.header["config"]["sample_id"] is not None:
                    sample_id = snf_in.header["config"]["sample_id"]
                else:
                    sample_id, _ = os.path.splitext(os.path.basename(input_filename))
            config.snf_input_info.append({"internal_id": snf_internal_id, "sample_id": sample_id, "filename": input_filename})
            reqc_info = ' (Rerunning QC)' if snf_in.reqc else ''
            snf_in.close()
            log.info(f"    {input_filename} (sample ID in output VCF='{sample_id}'{reqc_info})")

        if not config.combine_consensus:
            for info in config.snf_input_info:
                config.sample_ids_vcf.append((info["internal_id"], info["sample_id"]))

        if snfp_filename := config.combine_population:
            from sniffles.snfp import PopulationSNF
            snfp = PopulationSNF.open(snfp_filename)

        # TODO: Assure header consistency across multiple .snfs
        if to_process := (config.contig or config.regions_by_contig):
            contig_lengths = [(name, length) for name, length in contig_lengths if name in to_process]

        # None leaves the choice of result class to CombineTask
        result_class = None
        if len(input_snfs_sample_ids) > config.combine_max_inmemory_results:
            log.info('Using tmp file aggregation for merge.')
            from sniffles.result import CombineResultTmpFile
            result_class = CombineResultTmpFile

            if config.sort:
                log.warning(
                    f'''Sniffles currently does not support sorting on a number of input files exceeding --combine-max-inmemory-results'''
                    f''' (currently set to {config.combine_max_inmemory_results}). Unsorted calls will be written to '''
                    f'''result-{config.run_id}-*-unsorted.part.vcf. Please use either --no-sort or increase the number of in-memory results.'''
                )
                if config.vcf_output_bgz:
                    util.fatal_error_main('BGZ files require sorting, and thus are currently not supported for the given number of input files.')

        if config.dev_population_snf:
            from sniffles.result import CombineResultTmpFilePopulationSNF
            result_class = CombineResultTmpFilePopulationSNF

        for contig_str, contig_length in contig_lengths:
            task = parallel.CombineTask(
                id=task_id,
                contig=contig_str,
                start=0,
                end=contig_length - 1,
                assigned_process_id=None,
                sv_id=0,
                config=config,
                result_class=result_class,
                regions=config.regions_by_contig.get(contig_str)
            )
            tasks.extend(task.scatter())
            if contig_str not in contig_tasks_intervals:
                contig_tasks_intervals[contig_str] = []
            contig_tasks_intervals[contig_str].append((task.start, task.end, task))
            # Continue numbering after the last sub-task produced by scatter()
            task_id = tasks[-1].id + 1

        log.info(f"Verified headers for {len(input_snfs_sample_ids)} .snf files.")

    if config.mode != "genotype_vcf" and config.vcf is not None:
        vcf_out.write_header(contig_lengths)
    elif config.mode == "genotype_vcf":
        vcf_out.rewrite_header_genotype(vcf_in.header_str)

    #
    # Start workers
    #
    if config.threads:
        for pnum in range(config.threads):
            processes.append(parallel.SnifflesWorker(process_id=pnum, config=config, tasks=tasks, recycle_hint=monitor))
    else:
        processes.append(parallel.SnifflesParentWorker(config=config, tasks=tasks))

    log.info("")
    if config.mode == "call_sample" or config.mode == "genotype_vcf":
        if config.input_is_cram:
            # CRAM file
            log.info("Analyzing alignments... (progress display disabled for CRAM input)")
        else:
            log.info(f"Analyzing {total_mapped} alignments total...")
    elif config.mode == "combine":
        log.info(f"Calling SVs across {len(config.snf_input_info)} samples ({total_mapped} candidates total)...")
    log.info("")

    #
    # Distribute analysis tasks to workers and collect results
    #
    analysis_start_time = time.monotonic()

    for p in processes:
        p.start()

    finished_tasks: list[parallel.Task] = []

    # The list (rather than a generator) is intentional: run_parent() has to be
    # called on every running worker in each round, without any() short-circuiting.
    while any([p.run_parent() for p in processes if p.running]):
        time.sleep(0.1)

    for p in processes:
        p.finalize()
        finished_tasks.extend(p.finished_tasks)

    log.info(f"Took {time.monotonic() - analysis_start_time:.2f}s.")
    log.info("")

    # Emit results in task id order, independent of which worker finished first
    finished_tasks.sort(key=lambda task: task.id)

    for t in finished_tasks:
        t.result.emit(vcf_out=vcf_out, snf_out=snf_out, **rkwargs)

    if snf_out:
        snf_candidate_count = snf_out.write_results(config, contigs)
        snf_out.close()
        log.info(f"Wrote {snf_candidate_count} SV candidates to {config.snf} (for multi-sample calling).")

    if psnf_out:
        c = psnf_out.write_results(config, contigs)
        psnf_out.close()
        log.info(f'Wrote {c} SVs to population SNF.')

    if DEV_MONITOR_MEM:
        logging.debug(f"[DEV: Total memory usage={dbg_get_total_memory_usage_MB():.2f}MB]")

    if config.vcf is not None:
        vcf_out.close()
        if config.vcf_output_bgz:
            vcf_index_start_time = time.time()
            log.info(f"Generating index for {config.vcf}...")
            try:
                pysam.tabix_index(config.vcf, preset="vcf", force=True)
            except Exception:
                # A failed index is logged but does not fail the run: the VCF itself is already written
                log.exception('Error indexing VCF.')
            else:
                log.info(f"Indexing VCF output took {time.time() - vcf_index_start_time:.2f}s.")

    if (config.mode == "call_sample" or config.mode == "combine") and config.vcf is not None:
        log.info(f"Wrote {vcf_out.call_count} called SVs to {config.vcf} {vcf_output_info_str}")

    if monitor:
        log.debug('Stopping resource monitoring.')
        monitor.stop()


if __name__ == "__main__":
    # Filled by Sniffles2_Main with every worker it creates, so that the
    # workers can be shut down below if the run is aborted.
    processes = []

    # Log to stdout; progress messages stay hidden unless their level is lowered later.
    try:
        logging.config.dictConfig({
            'version': 1,
            'formatters': {
                'default': {
                    'format': '%(asctime)s %(levelname)s %(name)s (%(process)d): %(message)s'
                }
            },
            'handlers': {
                'console': {
                    'class': 'logging.StreamHandler',
                    'formatter': 'default',
                    'stream': 'ext://sys.stdout',
                }
            },
            'loggers': {
                'sniffles.progress': {
                    'level': logging.WARNING,
                },
                'sniffles.vcf': {
                    'level': logging.INFO,
                }
            },
            'root': {
                'level': logging.INFO,
                'handlers': ['console'],
            },
            'disable_existing_loggers': False,
        })
    except (ValueError, TypeError, AttributeError, ImportError):
        logging.exception('Error configuring loggers.')

    try:
        Sniffles2_Main(processes)
    except (util.Sniffles2Exit, SystemExit) as exit_code:
        if processes:
            # Allow time for child process error messages to propagate
            print("Sniffles2Main: Shutting down workers")
            time.sleep(10)

        # Best-effort shutdown: a worker that has no underlying process or has
        # already exited must not prevent the remaining workers from being stopped.
        shutdown_log = logging.getLogger('sniffles.main')
        for proc in processes:
            try:
                proc.process.terminate()
            except Exception:
                shutdown_log.debug('Could not terminate worker.', exc_info=True)

        for proc in processes:
            try:
                proc.process.join()
            except Exception:
                shutdown_log.debug('Could not join worker.', exc_info=True)

        # util.Sniffles2Exit is not visible here; fall back to exit status 1
        # should it not carry a `code` attribute the way SystemExit does.
        sys.exit(getattr(exit_code, "code", 1))
    except BaseException:
        # Deliberately catches everything (including KeyboardInterrupt) so that
        # any abnormal end is logged and reported with exit status 1.
        logging.getLogger('sniffles.main').exception('Unhandled error while running sniffles.')
        sys.exit(1)
