#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import gzip
import sys
from argparse import ArgumentParser
from collections import namedtuple
from struct import pack, unpack_from


# map from mario paint versions to offsets of song data in RAM
# ('jp/us' is the spelling accepted by the -r command-line option; 'jp/na' is
# kept as an alias so lookups using the older key keep working)
RAM_OFFSETS = {
    'jp/us': 0x9E4,
    'jp/na': 0x9E4,
    'eu': 0x9E4,
    'proto': 0x9DE,
}

# map from IT pitch values to mario paint pitch values
# (the keys are the IT pitches listed below in ascending order; the mario
# paint pitch value is simply the position of the pitch in that sequence)
NOTEMAP = {
    it_pitch: mp_pitch
    for mp_pitch, it_pitch in enumerate((
        45, 47,                      # A-3 B-3
        48, 50, 52, 53, 55, 57, 59,  # C-4 D-4 E-4 F-4 G-4 A-4 B-4
        60, 62, 64, 65, 67, 69, 71,  # C-5 D-5 E-5 F-5 G-5 A-5 B-5
    ))
}


# mutable stand-in for a namedtuple (fields are reassigned after loading)
class State:
    """Plain attribute bag holding the pieces of a savestate.

    Every keyword argument becomes an instance attribute of the same name.
    load_savestate fills in, in file order: pre_data, song_data, end, loop,
    unused1, tempo, unused2, time_sig, unused3, undo_data, post_data.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

# immutable struct: one decoded pattern event (one channel on one row)
Cell = namedtuple('Cell', 'note inst vol cmd cmdval')


# read IT order list
def read_orders(data):
    """Return the order list of an IT module as a tuple of ints.

    data is the raw module contents.  The 16-bit order count is read from
    header offset 0x20 and that many one-byte entries are read starting at
    offset 0xC0.
    """
    # '<' pins the byte order: IT headers are little-endian, whatever the
    # byte order of the machine running this script
    ordnum = unpack_from('<H', data, 0x20)[0]
    return unpack_from('<%dB' % ordnum, data, 0xC0)


# determine IT pattern offsets based on header data
def pattern_offsets(data):
    """Return the file offsets of all patterns in an IT module.

    data is the raw module contents.  The order, instrument, sample and
    pattern counts are four consecutive 16-bit values at header offset 0x20.
    Returns a tuple with one 32-bit offset per pattern.
    """
    # '<' pins the byte order: IT headers are little-endian
    ordnum, insnum, smpnum, patnum = unpack_from('<4H', data, 0x20)
    # the pattern offset table sits behind the order list (1 byte per entry)
    # and the instrument and sample offset tables (4 bytes per entry)
    table = 0xC0 + ordnum + 4 * (insnum + smpnum)
    return unpack_from('<%dI' % patnum, data, table)


# read one pattern from an IT module
def read_pattern(data, offset):
    """Decode one packed IT pattern into per-channel lists of cells.

    data is the raw module contents and offset the file offset of the
    pattern header.  An offset of 0 denotes an empty 64-row pattern for
    which the module stores no data.

    Returns a list of 4 lists (channels 0-3; events on later channels are
    parsed but discarded).  Each inner list has one entry per row: a Cell,
    or None where the channel has no event on that row.  Cell fields that
    are absent from an event are None.
    """
    if offset == 0:
        # empty pattern: nothing to parse (parsing at offset 0 would
        # misread the module header as pattern data)
        return [[None] * 64 for _ in range(4)]

    # header: a 16-bit value that is not needed here, the 16-bit row count,
    # then 4 more bytes that are skipped; '<' pins little-endian byte order
    _, rows = unpack_from('<HH', data, offset)
    offset += 8

    # per-channel memory of the last mask and the last value of each field
    prev_maskvar, prev_note, prev_ins = ([0] * 64 for i in range(3))
    prev_vol, prev_cmd, prev_cmdval = ([0] * 64 for i in range(3))
    cells = [[None] * rows for _ in range(4)]

    for row in range(rows):
        while True:
            channelvariable = unpack_from('B', data, offset)[0]
            offset += 1
            if channelvariable == 0:
                break  # end of row
            channel = (channelvariable - 1) & 63
            if channelvariable & 128:
                # a new mask byte follows
                maskvar = unpack_from('B', data, offset)[0]
                offset += 1
            else:
                # reuse the mask last seen on this channel
                maskvar = prev_maskvar[channel]
            prev_maskvar[channel] = maskvar

            # mask bits 1/2/4/8: the field is stored in the stream
            if maskvar & 1:
                note = unpack_from('B', data, offset)[0]
                prev_note[channel] = note
                offset += 1
            else:
                note = None

            if maskvar & 2:
                ins = unpack_from('B', data, offset)[0]
                prev_ins[channel] = ins
                offset += 1
            else:
                ins = None

            if maskvar & 4:
                vol = unpack_from('B', data, offset)[0]
                prev_vol[channel] = vol
                offset += 1
            else:
                vol = None

            if maskvar & 8:
                cmd, cmdval = unpack_from('BB', data, offset)
                prev_cmd[channel], prev_cmdval[channel] = cmd, cmdval
                offset += 2
            else:
                cmd, cmdval = None, None

            # mask bits 16/32/64/128: repeat the channel's last value
            if maskvar & 16:
                note = prev_note[channel]
            if maskvar & 32:
                ins = prev_ins[channel]
            if maskvar & 64:
                vol = prev_vol[channel]
            if maskvar & 128:
                cmd = prev_cmd[channel]
                cmdval = prev_cmdval[channel]

            # only the first four channels are kept
            if channel < 4:
                cells[channel][row] = Cell(note, ins, vol, cmd, cmdval)

    return cells


# read all patterns in an IT module
def read_patterns(data):
    """Decode every pattern of an IT module.

    data is the raw module contents.  Returns a tuple with one decoded
    pattern (see read_pattern) per entry of the pattern offset table, in
    table order.
    """
    return tuple(read_pattern(data, start) for start in pattern_offsets(data))


# read song data from an IT module
def load_module(f):
    """Read the song data of an IT module from a binary file object.

    Returns a tuple (speed, tempo, orders, patterns): the two bytes at
    header offset 0x32, the order list and the decoded patterns.

    Raises ValueError if the data does not start with the 'IMPM' signature.
    """
    data = f.read()
    # compare as bytes: decoding first would turn a non-ASCII signature into
    # a UnicodeDecodeError instead of the intended error below
    if data[:4] != b'IMPM':
        raise ValueError('invalid IT module')
    speed, tempo = unpack_from('BB', data, 0x32)
    orders = read_orders(data)
    patterns = read_patterns(data)
    return speed, tempo, orders, patterns


# return offset of ram block in an sst
def find_ram(buf):
    """Return the offset of the RAM block's payload in a savestate buffer.

    Blocks are walked starting at offset 0xE, right after the initial
    signature.  Each block has an 11-byte header: a 3-byte label in bytes
    0-2 and a 6-digit decimal payload length in bytes 4-9 (bytes 3 and 10
    are skipped; presumably separators).

    Raises ValueError if the buffer ends before a RAM block is found, or if
    a block length is not a decimal number.
    """
    index = 0xE  # after initial signature
    while index < len(buf):
        if buf[index:index+3] == b'RAM':
            return index + 11  # payload starts right after the header
        index += 11 + int(buf[index+4:index+10])  # skip block
    raise ValueError('no RAM block found in savestate')

# read an sst from a file (should be opened as gzip)
def load_savestate(f, version):
    """Read a savestate from a file object and split it around the song data.

    f must yield the uncompressed savestate (i.e. be opened through gzip);
    version is a key of RAM_OFFSETS selecting where the song data sits
    inside the RAM block.

    Returns a State whose fields, in file order, are: pre_data (everything
    before the song data), song_data (576 bytes), end and loop (16 bits
    each), unused1 (5 bytes), tempo (1 byte), unused2 (4 bytes), time_sig
    (1 byte), unused3 (9 bytes), undo_data (592 bytes) and post_data
    (everything after).  The layout mirrors what dump_savestate writes.

    Raises KeyError for an unknown version.
    """
    buf = f.read()
    offset = find_ram(buf) + RAM_OFFSETS[version]
    # end and loop are two adjacent 16-bit values at +576 and +578 (matching
    # the 'HH' pair written by dump_savestate); '<' pins little-endian
    end, loop = unpack_from('<HH', buf, offset+576)
    return State(
        pre_data=buf[:offset],
        song_data=unpack_from('576B', buf, offset),
        end=end,
        loop=loop,
        unused1=buf[offset+580:offset+585],
        tempo=unpack_from('B', buf, offset+585)[0],
        unused2=buf[offset+586:offset+590],
        time_sig=unpack_from('B', buf, offset+590)[0],
        unused3=buf[offset+591:offset+600],
        undo_data=unpack_from('592B', buf, offset+600),
        post_data=buf[offset+1192:],
    )


# write an sst to a file (should be opened as gzip)
def dump_savestate(state, f):
    """Write a savestate to a file object (should be opened through gzip).

    state provides the fields pre_data, song_data (576 byte values), end,
    loop, unused1, tempo, unused2, time_sig, unused3, undo_data (592 byte
    values) and post_data; they are written back in exactly that order.

    All packing happens before anything is written, so a value that does
    not fit its field raises struct.error without a partial write.
    """
    f.write(b''.join((
        state.pre_data,
        pack('576B', *state.song_data),
        # '<' pins the byte order: 16-bit values are stored little-endian
        pack('<HH', state.end, state.loop),
        state.unused1,
        pack('B', state.tempo),
        state.unused2,
        pack('B', state.time_sig),
        state.unused3,
        pack('592B', *state.undo_data),
        state.post_data,
    )))


# main routine if invoked as script
if __name__ == '__main__':
    # parse command-line args
    parser = ArgumentParser(
        description='Inject Impulse Tracker pattern data into a SNES9x ' +
                    'Mario Paint savestate.')
    parser.add_argument('source', help='source IT module')
    parser.add_argument('target', help='target SST file')
    parser.add_argument('-r', choices=('jp/us', 'eu', 'proto'),
                        dest='rversion', default='jp/us',
                        help='release version of target ROM (default jp/us)')
    parser.add_argument('--version', action='version', version='%(prog)s 0.3')
    args = parser.parse_args()

    # RAM_OFFSETS may list the japanese/north american release as 'jp/na'
    # rather than under the 'jp/us' spelling used on the command line
    rversion = args.rversion
    if rversion == 'jp/us' and rversion not in RAM_OFFSETS:
        rversion = 'jp/na'

    # read savestate
    with gzip.open(args.target, 'rb') as f:
        state = load_savestate(f, rversion)

    # read IT pattern data
    with open(args.source, 'rb') as f:
        speed, tempo, orders, patterns = load_module(f)
    if speed == 0:
        # the tempo conversion below divides by the speed
        sys.exit('error: IT module has an initial speed of 0')

    # construct new song data from IT patterns
    max_beats = 96  # 96 is the normal mario paint maximum
    blank_note, blank_inst = 255, 223
    # 3 channels per beat, 2 bytes (note, instrument) per channel
    song_data = [blank_note, blank_inst] * 3 * max_beats
    beat = 0
    unmapped = 0  # number of notes that have no mario paint pitch
    for order in orders:
        if beat >= max_beats or order == 255:  # 255 marks the end of song
            break
        if order == 254:  # "+++" separator entry, not a pattern number
            continue
        pattern = patterns[order]
        for row in range(len(pattern[0])):
            for channel in range(3):
                cell = pattern[channel][row]
                if cell is None or cell.note is None or cell.inst is None:
                    continue
                pos = beat * 6 + channel * 2
                # notes without a mario paint equivalent are left blank
                # instead of aborting the whole conversion
                if cell.note not in NOTEMAP:
                    unmapped += 1
                song_data[pos] = NOTEMAP.get(cell.note, blank_note)
                # change blank instrument to rocket icon
                inst = 0x20 if cell.inst == 0x10 else cell.inst
                # mario paint insts start at 0, IT insts start at 1
                song_data[pos + 1] = inst - 1 if inst else blank_inst
            beat += 1
            if beat >= max_beats:
                break
    if unmapped:
        print('warning: %d note(s) have no mario paint pitch and were left '
              'blank' % unmapped, file=sys.stderr)
    state.song_data = song_data  # always 6 * max_beats = 576 bytes

    # set other params based on IT data
    state.end = 16 + 8 * beat
    hz = 50 if args.rversion == 'eu' else 60
    # 13/3 is arbitrary; clamp to one byte so that writing the savestate
    # cannot fail after the target file has already been truncated
    state.tempo = min(255, tempo * 24 * 13//3 // speed // hz)
    state.loop = 0

    # write modified savestate
    with gzip.open(args.target, 'wb') as f:
        dump_savestate(state, f)
