#!/usr/bin/env python
# Copyright 2025 Irwin Jungreis
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The maf files for the 447 mammal alignment include some blocks in which a single region in
    a non-reference species was split into more than one line.
Given a maf file, find cases where this happens in order to find how prevalent it is.
If requested using the -allowGaps option, also look for cases where the two lines have
    a gap in coordinates between them up to a certain size. In these cases, to fix things
    the block would have to be split.
"""
from __future__ import division, print_function
import sys, os, itertools
from collections import OrderedDict

def main() :
    """
    Command-line entry point.

    Usage: MafFile OutputFile [-allowGaps MAX_GAP]

    For every alignment block in MafFile, group the sequence lines by src, sort each
        group by start, and look for runs of two or more lines in which each line starts
        where the previous one ended (or up to MAX_GAP bases later; default 0) and the
        last non-dash column of the previous line is to the left of the first non-dash
        column of the next line.
    For each such run, write to OutputFile the first line of the block (the reference)
        followed by the lines of the run, each as tab-separated src, start, size, strand,
        and bases, and then an empty line.
    """
    maxGap = get_associated_arg_int('-allowGaps', 0, remove = True)
    assert_num_args(2, 'MafFile OutputFile [-allowGaps MAX_GAP]', exact = True)
    mafName, outName = sys.argv[1 : 3]

    def continues_region(prevSeq, nextSeq) :
        # True if nextSeq picks up where prevSeq left off, both in source coordinates
        # (allowing a gap of up to maxGap) and in alignment columns.
        # NOTE(review): lines are grouped by src only, so lines with the same src but
        #   different strands are compared using their raw start values -- confirm
        #   that this is intended.
        prevEnd = prevSeq.start + prevSeq.size
        if prevEnd > nextSeq.start or prevEnd < nextSeq.start - maxGap :
            return False
        return _index_of_last(prevSeq.bases()) < _index_of_first(nextSeq.bases())

    with myopen(mafName, 'rt') as aliFile, myopen(outName, 'wt') as outFile :

        def write_row(seq) :
            print(seq.src, seq.start, seq.size, seq.strand, seq.bases(),
                  sep = '\t', file = outFile)

        for block in _iter_aliblocks(aliFile) :
            refSeq = block.seqs[0]
            seqsBySrc = multidict((seq.src, seq) for seq in block.seqs)
            for seqs in seqsBySrc.values() :
                seqs.sort(key = lambda seq : seq.start)
                for chunk in chunkify(seqs, continues_region) :
                    if len(chunk) > 1 :
                        write_row(refSeq)
                        for seq in chunk :
                            write_row(seq)
                        print(file = outFile)

def _index_of_first(s) :
    """
    Return the index in s of the first character that is not '-'.
    If s is empty or consists entirely of '-', the result is len(s).
    """
    withoutLeadingDashes = s.lstrip('-')
    numLeadingDashes = len(s) - len(withoutLeadingDashes)
    return numLeadingDashes

def _index_of_last(s) :
    """
    Return the index in s of the last character that is not '-'.
    If s is empty or consists entirely of '-', the result is -1.
    """
    withoutTrailingDashes = s.rstrip('-')
    return len(withoutTrailingDashes) - 1


def _iter_aliblocks(aliFile, startFilePos = None, endFilePos = None) :
    """
    Parse an alignment (maf) file and yield one _AliBlock per alignment block.

    Only 'a' and 's' lines are used. Lines whose first character is '#' and lines of
        any other type (i, e, q, ...) are ignored. A block starts at an 'a' line and
        ends at the next blank line, the next 'a' line, or when reading stops.
    aliFile: open file object supporting tell() and readline(), and seek() if
        startFilePos is given.
    startFilePos: if not None, seek to this position before reading.
    endFilePos: if not None, stop before reading any line that starts at or after this
        position.
    Each yielded block has:
        score: the text following 'score=' on the 'a' line (a string), or None.
        seqs: one _AliSequence per 's' line, in file order.
        refStart, refSize: start and size of the first 's' line, or None if none.
        seekPos: file position of the start of the 'a' line.
        nextSeekPos: file position of the start of the line at which the block ended.
    Raise AlignmentCacheError if an 's' line is found when no block is open.
    """
    if startFilePos is not None :
        aliFile.seek(startFilePos)
    curBlock = None
    lineCount = -1
    while True :
        lineCount += 1
        lineStartPos = aliFile.tell()
        if endFilePos is not None and lineStartPos >= endFilePos :
            # NOTE(review): if endFilePos falls in the middle of a block, the part of
            #   the block read so far is still yielded below -- presumably callers
            #   always pass a block boundary; confirm.
            break
        line = aliFile.readline()
        if line == '' :
            break # End of file
        if line[0] == '#' :
            continue
        strs = line.split()
        lineType = strs[0] if strs else None # None means a blank line

        if lineType is None or lineType == 'a' :
            # A blank line or the start of a new block both end the current block.
            if curBlock is not None :
                curBlock.nextSeekPos = lineStartPos
                yield curBlock
                curBlock = None
            if lineType == 'a' :
                # The constructor leaves score, refStart and refSize as None and seqs
                # as a new empty list.
                curBlock = _AliBlock(seekPos = lineStartPos)
                for st in strs :
                    if st.startswith('score=') :
                        curBlock.score = st[6 :]  # String, not float, since sometimes NA
        elif lineType == 's' :
            seq = _AliSequence(src = strs[1], start = int(strs[2]), size = int(strs[3]),
                               strand = strs[4], srcSize = int(strs[5]), bases = strs[6])
            # curBlock can't be None unless there is an 's' line with no 'a' line before
            # it, which should have been caught during indexing, but check anyway
            if curBlock is None :
                raise AlignmentCacheError('s line without preceding a line at line %d: %s'
                                          % (lineCount, line))
            curBlock.seqs.append(seq)
            if curBlock.refStart is None :
                # The first 's' line of a block is the reference.
                curBlock.refStart = seq.start
                curBlock.refSize = seq.size
        # Any other line type (i, e, q, ...) is ignored.

    # Yield final block
    if curBlock is not None :
        curBlock.nextSeekPos = lineStartPos
        yield curBlock

class _AliBlock(object) :
    """
    One alignment block: a score, the list of aligned sequences, the position of the
        block in the reference (first) sequence, and the file positions at which the
        block starts and ends.
    """
    # Note: positions in _AliBlock members are 0-based, whereas gff/gtf files are 1-based.
    # However, most _AliSequence functions accept and return 1-based coordinates.
    def __init__(self, score = None, refStart = None, refSize = None, seqs = None,
                 seekPos = None, nextSeekPos = None) :
        if seqs is None :
            seqs = [] # Don't put mutable empty list as default arg!
        self.score       = score    # Kept as string because sometimes it's NA.
        self.refStart    = refStart # 0-based start position in reference (i.e., 1st) src
        self.refSize     = refSize  # length in the reference src (not including gaps)
        self.seqs        = seqs     # list of _AliSequence
        self.seekPos     = seekPos     # offset/virtual offset of first position of block
        self.nextSeekPos = nextSeekPos # offset/virtual offset of 1st position after block
    def __str__(self) :
        # refStart and refSize are None until the first sequence has been seen, so
        # format them with %s; %d would raise TypeError for None.
        return '_AliBlock(refStart = %s, refSize = %s, %d seqs)' % (
            self.refStart, self.refSize, len(self.seqs))
    __repr__ =  __str__

class _AliSequence(object) :
    """
    One sequence line of an alignment block: the source it comes from, where the
        aligned region lies in that source, and the aligned text (bases and gaps).
    """
    def __init__(self, src, start, size, strand, srcSize, bases) :
        numNonDash = len(bases.replace('-', ''))
        assert size == numNonDash, (src, start, size, numNonDash)
        # src: Assembly.Chrom
        self.src = src
        # start: Start of aligned region in source sequence (0-based). If strand is '-'
        #   this is relative to reverse-complemented source sequence.
        self.start = start
        # size: Number of non-dash characters in bases.
        self.size = size
        # strand: '+' or '-'
        self.strand = strand
        # srcSize: Size of entire source sequence, not just aligned parts
        self.srcSize = srcSize
        # Changing this will invalidate _cachedIndPair
        self._basesNeverChangeThis = bases
        # Cache pair (nonGapInd, ind) to speed up coord_to_ind
        self._cachedIndPair = None

    def bases(self) :
        """Return the aligned text (bases and gap characters) of this sequence."""
        return self._basesNeverChangeThis

    def get_start_and_end_pos(self) :
        """ Return the 1-based positions of the first and last bases of the sequence
            relative to the plus strand."""
        if self.strand == '+' :
            return self.start + 1, self.start + self.size
        # Any other strand: start is measured from the other end of the source.
        plusStrandEnd = self.srcSize - self.start
        return plusStrandEnd - self.size + 1, plusStrandEnd

    def assem_and_chrom(self) :
        """Split src at its first '.' and return (assembly, chromosome).
           If src contains no '.', return (src, None)."""
        assem, dot, chrom = self.src.partition('.')
        if dot == '' :
            return self.src, None
        return assem, chrom

    def assem(self) :
        """Return the part of src before the first '.', or all of src if there is none."""
        assem, _ = self.assem_and_chrom()
        return assem

    def chrom(self) :
        """Return the part of src after the first '.', or None if there is none."""
        _, chrom = self.assem_and_chrom()
        return chrom

    def coord_to_ind(self, coord) :
        """Return index in bases of the non-gap base whose 1-based coordinate is coord."""

        # targetNonGap is the index within the non-gap bases, e.g.,
        #     bases:       --C-GT-TT-
        #     ind:         0123456789
        #     nonGapInd:     0 12 34
        targetNonGap = coord - self.start - 1
        assert 0 <= targetNonGap < self.size, (targetNonGap, self.size)
        alignedText = self.bases()
        if len(alignedText) == self.size :
            # No gaps, so the index in bases and the non-gap index are the same.
            return targetNonGap

        # Begin the walk at whichever known (nonGapInd, ind) pair is closest to the
        # target: the start, the end, or the cached result of the previous call.
        anchors = [(0, 0), (self.size - 1, len(alignedText) - 1)]
        if self._cachedIndPair is not None :
            anchors.append(self._cachedIndPair)
        curNonGap, curInd = min(anchors,
                                key = lambda anchor : abs(targetNonGap - anchor[0]))
        if curNonGap < targetNonGap or curNonGap == targetNonGap == 0 :
            step = 1
        else :
            step = -1
        while curNonGap != targetNonGap :
            if isnt_gap(alignedText[curInd]) :
                curNonGap += step
            curInd += step
        while is_gap(alignedText[curInd]) : # Trim gaps off start or end
            curInd += step

        # Note: we are relying on this _AliSequence being the one in the cached block.
        #     If the cached block were copied to a working copy then setting
        #     _cachedIndPair in the working copy wouldn't do much good.
        self._cachedIndPair = (targetNonGap, curInd)
        return curInd

    def __str__(self) :
        firstPos, lastPos = self.get_start_and_end_pos()
        alignedText = self.bases()
        if len(alignedText) < 12 :
            shownText = alignedText
        else :
            # Abbreviate long text to its first 6 and last 3 characters.
            shownText = alignedText[:6] + '...' + alignedText[-3:]
        return '_AliSequence(%s:%d-%d, %s, %d bases, %s (len %d))' % (
            self.src, firstPos, lastPos, self.strand, self.size,
            shownText, len(alignedText))
    __repr__ =  __str__

class AlignmentCacheError(Exception) :
    """Raised by _iter_aliblocks for an 's' line that has no preceding 'a' line."""

# Characters that is_gap() and isnt_gap() treat as gaps in aligned text.
GapBases = '.-|'

def is_gap(base) :
    """Return True if base is found in GapBases, otherwise False."""
    gapChars = GapBases
    return base in gapChars

def isnt_gap(base) :
    """Return the negation of is_gap(base)."""
    if is_gap(base) :
        return False
    return True

def myopen(path, *pArgs, **kArgs) :
    """
    Open the file, after changing the path to an absolute path.
    If file extension is .gz, use gzip to open the file (but for .gz
        make text mode the default for consistency with open, even though gzip makes
        binary mode the default).
    If not writing and file doesn't exist, try to find the file by adding or removing .gz
        and open that one instead. Modes containing 'w', 'a', or 'x' count as writing,
        because each of them may create the named file.
    path: str or bytes path; ~ and relative components are resolved.
    pArgs, kArgs: passed through to open or gzip.open; the mode may be given either as
        the first positional argument or as the mode keyword.
    Return the open file object.
    """
    path = get_absolute_path(path)
    requestedMode = pArgs[0] if len(pArgs) > 0 else kArgs.get('mode', '')
    # Creating modes must use exactly the path that was asked for: redirecting an 'x'
    # open to an existing .gz sibling would fail, and redirecting an 'a' open would
    # append to a different file than the one named.
    writing = any(flag in requestedMode for flag in 'wax')
    gzExt  = b'.gz'  if isinstance(path, bytes) else '.gz'
    if not writing and not file_exists(path) :
        if file_exists(path + gzExt) :
            path = path + gzExt
        elif path.endswith(gzExt) and file_exists(path[:-3]) :
            path = path[:-3]
    if not path.endswith(gzExt) :
        return open(path, *pArgs, **kArgs)

    import gzip
    # gzip.open defaults to binary. Change default to text for consistency with open.
    mode = pArgs[0] if len(pArgs) > 0 else kArgs.setdefault('mode', 'rt')
    if 'b' not in mode and 't' not in mode :
        if len(pArgs) > 0 :
            pArgs = (pArgs[0] + 't',) + pArgs[1:]
        else :
            kArgs['mode'] += 't'
    return gzip.open(path, *pArgs, **kArgs)

def get_absolute_path(path) :
    """Return path as an absolute path, after expanding ~ and resolving . and .. """
    withHomeExpanded = os.path.expanduser(path)
    return os.path.abspath(withHomeExpanded)

def file_exists(fileName, ignoreGz = False) :
    """
    Return True if fileName (converted to an absolute path) exists.
    If ignoreGz, also return True if the same path with '.gz' added, or with a
        trailing '.gz' removed, exists.
    """
    fullPath = get_absolute_path(fileName)
    if os.path.exists(fullPath) :
        return True
    if not ignoreGz :
        return False
    gzExt = b'.gz' if isinstance(fullPath, bytes) else '.gz'
    # Look for the counterpart that differs only by the .gz extension.
    if fullPath.endswith(gzExt) :
        counterpart = fullPath[:-3]
    else :
        counterpart = fullPath + gzExt
    return os.path.exists(counterpart)

class multidict(OrderedDict) :
    """
    A dictionary in which each key is associated with a list of items.

    Keys keep their insertion order, but comparison with == and != ignores order.
    Looking up a key that is not present (with [] or get) returns a new empty list
        rather than raising KeyError; the key is not added to the dictionary.
    To delete a key and its list, use del self[key].
    """
    # Marker that lets get() tell "no default given" apart from any real default,
    # including None.
    _NO_DEFAULT = object()

    def __init__(self, itemIterator = None) :
        """itemIterator: optional iterable of (key, value) pairs to add."""
        super(multidict, self).__init__()
        if itemIterator is not None :
            self.update(itemIterator)
    def __eq__(self, other) :
        """Want order-independent comparison, but OrderedDict is order-dependent, so
           go to parent of OrderedDict, which is dict"""
        return super(OrderedDict, self).__eq__(other)
    def __ne__(self, other) :
        """Negation of __eq__. Without this the order-dependent != of OrderedDict would
           be inherited, and a == b and a != b could both be true."""
        isEqual = self.__eq__(other)
        if isEqual is NotImplemented :
            return isEqual
        return not isEqual
    def add(self, key, value) :
        """Append value to the list for key, creating the list if needed."""
        self.setdefault(key, [])
        self[key].append(value)
    def update(self, itemIterator) : # Add all key-item pairs in the iterator.
        for item in itemIterator :
            self.add(*item)
    def __getitem__(self, key) :
        return self.get(key)
    def get(self, key, default = _NO_DEFAULT) :
        """Return list for key, or default if key not present.
           Differs from parent method in that default is [] if not specified."""
        if default is multidict._NO_DEFAULT :
            # A new list for every call: a single shared default list would keep
            # whatever a caller appended to it and hand that to later lookups.
            default = []
        return super(multidict, self).get(key, default)
    def to_dict(self) :
        """Return a dictionary with the same items, which can then be used without
           needing the definition of multidict."""
        return dict(self.items())

def chunkify(iterable, continueChunk) :
    """
    Split the iterable into runs of consecutive elements, given a function that
        specifies whether a run continues from one element to the next.
        (Search terms: cluster, clump)
    continueChunk = lambda prevElt, thisElt : True or False
    Yield each run as a list. Nothing is yielded for an empty iterable.
    """
    run = []
    for elt in iterable :
        if run and not continueChunk(run[-1], elt) :
            # elt does not belong with the previous element: emit the finished run
            # and start a new one.
            yield run
            run = []
        run.append(elt)
    if run :
        yield run

def assert_num_args(numArgsRequired, correctUsageString, exact = False,
                    maxNumAllowed = None, silent = False, forceFailure = False) :
    """
    Verify the number of command-line arguments (not counting the program name).
    Stop with exit status 1, after printing 'Usage: ' + correctUsageString to stderr
        (unless silent), if any of the following holds:
        - there are fewer than numArgsRequired arguments;
        - exact is true and there are more than numArgsRequired arguments;
        - maxNumAllowed is not None and there are more than maxNumAllowed arguments;
        - forceFailure is true (useful if something else is wrong with the arguments).
    Otherwise return None.
    """
    numArgs = len(sys.argv) - 1
    tooFew = numArgs < numArgsRequired
    tooMany = ((exact and numArgs > numArgsRequired) or
               (maxNumAllowed is not None and numArgs > maxNumAllowed))
    if not (tooFew or tooMany or forceFailure) :
        return
    if not silent :
        print('Usage: ' + correctUsageString, file = sys.stderr)
    raise SystemExit(1)

def _get_arg_pos(string, silent = False) :
    """
    Return the position of the first element of sys.argv equal to string, or -1 if
        there is none.
    If it is found and not silent, print a message to stderr that the argument is
        being used.
    """
    try :
        pos = sys.argv.index(string)
    except ValueError :
        return -1
    if not silent :
        print('Using command line argument: ' + string, file = sys.stderr)
    return pos

def get_associated_arg(string, default = '', remove = False, silent = False,
                       numArgs = 1, mappers = None) :
    """
    Find the command-line argument named string and return the numArgs argument(s)
        that follow it: a single value if numArgs == 1, otherwise a list.
    If string is not on the command line, return default (mappers are not applied).
    If string is there but fewer than numArgs arguments follow it, print a message to
        stderr and stop with exit status 1.
    If remove is true, delete string and its associated argument(s) from sys.argv.
    If mappers is not None, it is a function that maps arguments (e.g., int) or a
        sequence of such functions, applied to the values in order and repeated
        cyclically if there are fewer functions than values. A None entry leaves the
        corresponding value unchanged.
    If not silent, print messages to stderr saying that the argument is being used and
        with what value(s).
    """
    pos = _get_arg_pos(string, silent)
    if pos < 0 :
        return default

    firstValPos = pos + 1
    pastLastValPos = firstValPos + numArgs
    if pastLastValPos > len(sys.argv) :
        # Not enough arguments after string.
        if numArgs == 1 :
            complaint = 'Argument %s requires a value.' % string
        else :
            complaint = 'Argument %s requires %d values.' % (string, numArgs)
        print(complaint, file = sys.stderr)
        raise SystemExit(1)

    assocArgs = list(sys.argv[firstValPos : pastLastValPos])
    if not silent :
        if numArgs == 1 :
            print('   with value %s' % assocArgs[0], file = sys.stderr)
        else :
            print('   with values %s' % assocArgs, file = sys.stderr)
    if remove :
        del sys.argv[pos : pastLastValPos]

    if mappers is not None :
        # A single function is treated as a one-element sequence of functions.
        try :
            iter(mappers)
        except TypeError :
            mappers = [mappers]

        pairs = zip(assocArgs, itertools.cycle(mappers))
        for argInd, (rawArg, mapper) in enumerate(pairs) :
            if mapper is not None :
                assocArgs[argInd] = mapper(rawArg)

    if numArgs == 1 :
        return assocArgs[0]
    return assocArgs

def get_associated_arg_int(*pArgs, **kArgs) :
    """
    Same as get_associated_arg, with the same arguments, except that mappers is always
        set to int, so values found on the command line are converted to int.
    """
    kArgs.update(mappers = int)
    return get_associated_arg(*pArgs, **kArgs)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__' :
    main()