#!/usr/bin/env python3
import sys, os, math
import ctypes, ctypes.util
from ctypes import *

# Shared LibBASS instance: created by Init() on first use, reset to None by Done().
BASS = None

class BASSError(RuntimeError):
    """Error reported by a BASS library call (raised by LibBASS.throw)."""

################################################################################

class LibBASS(object):
    """ctypes binding to the BASS run-time library.

    On construction a BASS shared library is located and every function in
    ``_prototypes`` is bound as an attribute without its ``BASS_`` prefix
    (e.g. ``self.ChannelPlay`` wraps ``BASS_ChannelPlay``).  ``throw()``
    converts the last BASS error code into a BASSError.
    """
    # ChannelIsActive return values
    STOPPED, PLAYING, STALLED, PAUSED = range(4)

    # (function name without "BASS_" prefix, return type, argument types...)
    _prototypes = [
        ("Init",                 c_int,    c_int, c_uint32, c_uint32, c_void_p, c_void_p),
        ("ErrorGetCode",         c_int),
        ("GetVolume",            c_float),
        ("GetCPU",               c_float),
        ("StreamCreateFile",     c_uint32, c_int, c_char_p, c_uint64, c_uint64, c_uint32),
        ("StreamCreateURL",      c_uint32, c_char_p, c_uint32, c_uint32, c_void_p, c_void_p),
        ("MusicLoad",            c_uint32, c_int, c_char_p, c_uint64, c_uint32, c_uint32, c_uint32),
        ("ChannelPlay",          c_int,    c_uint32, c_int),
        ("ChannelPause",         c_int,    c_uint32),
        ("ChannelStop",          c_int,    c_uint32),
        ("ChannelGetLength",     c_int64,  c_uint32, c_uint32),
        ("ChannelGetPosition",   c_int64,  c_uint32, c_uint32),
        ("ChannelSetPosition",   c_int,    c_uint32, c_int64, c_uint32),
        ("ChannelBytes2Seconds", c_double, c_uint32, c_int64),
        ("ChannelSeconds2Bytes", c_int64,  c_uint32, c_double),
        ("ChannelIsActive",      c_uint32, c_uint32),
        ("ChannelGetAttribute",  c_int,    c_uint32, c_uint32, POINTER(c_float)),
        ("ChannelSetAttribute",  c_int,    c_uint32, c_uint32, c_float),
        ("ChannelGetLevel",      c_uint32, c_uint32),
        ("ChannelGetData",       c_int32,  c_uint32, c_void_p, c_uint32),
        ("StreamFree",           c_int,    c_uint32),
        ("MusicFree",            c_int,    c_uint32),
        ("Free",                 c_int)
    ]

    # BASS_ERROR_* code -> name suffix, used to build throw() messages
    _errors = {
        -1: 'UNKNOWN',
         1: 'MEM',       2: 'FILEOPEN',  3: 'DRIVER',    4: 'BUFLOST',
         5: 'HANDLE',    6: 'FORMAT',    7: 'POSITION',  8: 'INIT',
         9: 'START',    10: 'SSL',      11: 'REINIT',   14: 'ALREADY',
        17: 'NOTAUDIO', 18: 'NOCHAN',   19: 'ILLTYPE',  20: 'ILLPARAM',
        21: 'NO3D',     22: 'NOEAX',    23: 'DEVICE',   24: 'NOPLAY',
        25: 'FREQ',     27: 'NOTFILE',  29: 'NOHW',     31: 'EMPTY',
        32: 'NONET',    33: 'CREATE',   34: 'NOFX',     37: 'NOTAVAIL',
        38: 'DECODE',   39: 'DX',       40: 'TIMEOUT',  41: 'FILEFORM',
        42: 'SPEAKER',  43: 'VERSION',  44: 'CODEC',    45: 'ENDED',
        46: 'BUSY',     47: 'UNSTREAMABLE', 48: 'PROTOCOL', 49: 'DENIED',
    }

    def __init__(self, libpath=None):
        """Load the BASS library and bind its functions.

        :param libpath: directory containing the library; the current
            directory is used when omitted.
        :raises ImportError: if no library file could be loaded, or a function
            from ``_prototypes`` is missing from the loaded library.
        """
        # system-specific adaptation (WinDLL/WINFUNCTYPE only exist on Windows)
        if sys.platform == 'win32':
            libclass, protoclass = WinDLL, WINFUNCTYPE
            prefix, suffix = "", ".dll"
        else:
            libclass, protoclass = CDLL, CFUNCTYPE
            prefix = "lib"
            suffix = ".dylib" if sys.platform == 'darwin' else ".so"

        # Try the architecture-specific file names in turn; a missing file and
        # a library built for the wrong architecture both fail with OSError.
        self.lib = None
        self.fixed_point_fft = False
        for arch_suffix in ("", "_x86", "_x64", "_armhf", "_arm"):
            path = os.path.join(libpath or '.', prefix + "bass" + arch_suffix + suffix)
            try:
                self.lib = libclass(path)
            except OSError:
                continue
            # Track.get_fft reads the ARM builds' FFT output as fixed-point ints
            self.fixed_point_fft = ("_arm" in arch_suffix)
            break
        if self.lib is None:
            raise ImportError("could not find BASS run-time library")

        # import functions from the library
        for proto in self._prototypes:
            name = proto[0]
            try:
                func = protoclass(*proto[1:])(("BASS_" + name, self.lib))
            except AttributeError:
                raise ImportError("function %r not found in BASS run-time library" % ("BASS_" + name))
            setattr(self, name, func)

    def throw(self, where=None, force=False):
        """Raise BASSError describing the last BASS error code.

        :param where: name of the failed function without the "BASS_" prefix;
            appended to the message as "in BASS_<where>()".
        :param force: raise even when the error code is 0 (BASS_OK).
        :returns: None when there is no error code and *force* is false.
        :raises BASSError: in every other case.
        """
        err = self.ErrorGetCode()
        if not err and not force:
            return
        if not err:
            name = "BASS_OK"
        elif err in self._errors:
            name = "BASS_ERROR_" + self._errors[err]
        else:
            name = "unknown error"
        msg = "BASS error %d (%s)" % (err, name)
        if where:
            msg += " in BASS_%s()" % where
        raise BASSError(msg)

    def __del__(self):
        # Best-effort BASS_Free() when the binding is collected.  Attributes
        # may be missing if __init__ failed part-way, so look them up
        # defensively instead of raising from the finalizer.
        free = getattr(self, "Free", None)
        if getattr(self, "lib", None) is not None and free is not None:
            free()
        self.lib = None

################################################################################

def Init(freq=44100, channels=2, bits=16, device=-1, libpath=None):
    """Load the BASS library (on first use) and initialize an output device.

    :param freq: output sample rate, passed to BASS_Init.
    :param channels: 1 (mono) or 2 (stereo).
    :param bits: 8 or 16 bits per sample.
    :param device: BASS device number, passed to BASS_Init.
    :param libpath: directory to load the BASS library from (see LibBASS).
    :raises ValueError: for unsupported *channels* / *bits* values.
    :raises BASSError: if BASS_Init fails.
    """
    global BASS

    # Validate before touching the library so bad arguments fail fast.
    flags = 0
    if bits == 8:
        flags |= 1  # BASS_DEVICE_8BITS
    elif bits != 16:
        raise ValueError("only 8 and 16 bits per sample are supported")
    if channels == 1:
        flags |= 2  # BASS_DEVICE_MONO; stereo needs no flag
    elif channels != 2:
        raise ValueError("only mono and stereo is supported")

    if not BASS:
        BASS = LibBASS(libpath)
    if not BASS.Init(device, freq, flags, None, None):
        # throw() adds the "BASS_" prefix itself
        BASS.throw("Init", force=True)

def Done():
    """Shut down BASS and forget the module-level library handle.

    Does nothing when BASS was never initialized or is already shut down.
    """
    global BASS
    if BASS:
        BASS.Free()
        # Rebinding the global drops this module's reference to the binding.
        BASS = None

class Track(object):
    """A BASS playback channel opened from a file, URL, memory buffer or MOD music.

    Requires Init() to have been called.  Besides transport control it
    exposes the playback position in several units: ``pos`` (seconds),
    ``order`` ((order, row) pair) and ``row`` (absolute row).  For non-MOD
    sources rows are derived from the tempo: ``rps`` (rows per second) =
    ``rpb`` (rows per beat) * ``bpm`` / 60.  For MOD music ``rpp`` (rows
    per pattern) converts between orders and absolute rows.
    """

    def __init__(self, source, memory=None, offset=0, size=0, loop=False, prescan=False, bpm=125, rpb=6, rpp=64):
        """Open *source* as a stream, a URL stream or MOD music (tried in that order).

        :param source: file name / URL, or the file contents (bytes) when
            loading from memory.
        :param memory: treat *source* as in-memory data; when None, sources
            of 4096 or more items are treated as data.
        :param offset: start offset into *source*.
        :param size: number of bytes to use; 0 means "up to the end".
        :param loop: set BASS_SAMPLE_LOOP / BASS_MUSIC_LOOP.
        :param prescan: set BASS_STREAM_PRESCAN / BASS_MUSIC_PRESCAN.
        :raises IOError: if none of the BASS loaders accepted the source.
        """
        assert BASS
        if memory is None:
            memory = (len(source) >= 4096)
        if memory:
            if offset or size:
                # apply the requested window here and hand BASS the sliced buffer
                self.data = source[offset:offset + size] if size else source[offset:]
                source = self.data
            else:
                self.data = source
            # self.data keeps the buffer alive for as long as BASS may read it
            offset = 0
            size = len(source)
        elif isinstance(source, str):
            source = source.encode('utf-8')
        flags = 0
        if loop:    flags |= 4        # BASS_SAMPLE_LOOP == BASS_MUSIC_LOOP
        if prescan: flags |= 0x20000  # BASS_STREAM_PRESCAN == BASS_MUSIC_PRESCAN
        self.ismod = False
        self.bpm = bpm
        self.rpb = rpb
        self.rpp = rpp
        self.rps = self.rpb * self.bpm * (1.0 / 60)  # rows per second

        # try to open the source: plain stream, then URL, then MOD music
        self.channel = BASS.StreamCreateFile(memory, source, offset, size, flags)
        if not self.channel and not memory:
            self.channel = BASS.StreamCreateURL(source, offset, flags, None, None)
        if not self.channel:
            flags |= 0x4000  # BASS_MUSIC_PT1MOD
            self.channel = BASS.MusicLoad(memory, source, offset, size, flags, 0)
            self.ismod = True
        if not self.channel:
            raise IOError("could not open track source")

    def restart(self):
        """Play from the beginning."""
        if not BASS.ChannelPlay(self.channel, True):
            BASS.throw("ChannelPlay")

    def play(self):
        """Start or resume playback at the current position."""
        if not BASS.ChannelPlay(self.channel, False):
            BASS.throw("ChannelPlay")

    def pause(self):
        """Pause playback; pausing a channel that is not playing is not an error."""
        if not BASS.ChannelPause(self.channel):
            if BASS.ErrorGetCode() != 24:  # BASS_ERROR_NOPLAY
                BASS.throw("ChannelPause")

    def stop(self):
        """Stop playback."""
        if not BASS.ChannelStop(self.channel):
            BASS.throw("ChannelStop")

    def __del__(self):
        # Free the channel unless BASS was already shut down by Done().
        # getattr() covers objects whose __init__ failed before these
        # attributes were assigned.
        channel = getattr(self, "channel", 0)
        if not BASS or not channel:
            return
        if getattr(self, "ismod", False):
            BASS.MusicFree(channel)
        else:
            BASS.StreamFree(channel)

    def get_state(self):
        """Return 'stopped', 'playing', 'stalled', 'paused' or 'unknown'."""
        state = BASS.ChannelIsActive(self.channel)
        return { 0: 'stopped', 1: 'playing', 2: 'stalled', 3: 'paused' }.get(state, 'unknown')
    state = property(get_state)

    def is_stopped(self): return (BASS.ChannelIsActive(self.channel) == 0)
    def is_playing(self): return (BASS.ChannelIsActive(self.channel) == 1)
    def is_stalled(self): return (BASS.ChannelIsActive(self.channel) == 2)
    def is_paused(self):  return (BASS.ChannelIsActive(self.channel) == 3)

    def get_length(self):
        """Length in seconds, or 0.0 if BASS reports an error."""
        pos = BASS.ChannelGetLength(self.channel, 0)  # BASS_POS_BYTE
        if pos < 0:
            return 0.0
        return BASS.ChannelBytes2Seconds(self.channel, pos)
    length = property(get_length)

    def get_pos(self):
        """Playback position in seconds, or 0.0 if BASS reports an error."""
        pos = BASS.ChannelGetPosition(self.channel, 0)  # BASS_POS_BYTE
        if pos < 0:
            return 0.0
        return BASS.ChannelBytes2Seconds(self.channel, pos)
    def set_pos(self, pos):
        """Seek to *pos* seconds."""
        pos = BASS.ChannelSeconds2Bytes(self.channel, pos)
        if not BASS.ChannelSetPosition(self.channel, pos, 0):  # BASS_POS_BYTE
            BASS.throw("ChannelSetPosition")
    pos = property(get_pos, set_pos)

    def get_order(self):
        """Return the position as an (order, row) pair.

        For MOD music this reads BASS_POS_MUSIC_ORDER, whose low word is the
        order and high word the row ((0, 0) on error).  Other sources report
        order 0 and the (fractional) absolute row.
        """
        if not self.ismod:
            return (0, self.get_row())
        pos = BASS.ChannelGetPosition(self.channel, 1)  # BASS_POS_MUSIC_ORDER
        if pos < 0:
            return (0, 0)
        return (pos & 0xFFFF, (pos >> 16) & 0xFFFF)
    def set_order(self, order):
        """Seek to an (order, row) pair; non-MOD sources only use the row."""
        pat, row = order
        if self.ismod:
            # BASS expects MAKELONG(order, row): order in the low word
            pos = int(pat) | (int(row) << 16)
            if not BASS.ChannelSetPosition(self.channel, pos, 1):  # BASS_POS_MUSIC_ORDER
                BASS.throw("ChannelSetPosition")
        else:
            self.set_pos(row / self.rps)
    order = property(get_order, set_order)

    def get_row(self):
        """Absolute row: order * rpp + row for MODs, pos * rps otherwise."""
        if self.ismod:
            pat, row = self.get_order()
            return pat * self.rpp + row
        else:
            return self.get_pos() * self.rps
    def set_row(self, row):
        """Seek to an absolute row (the inverse of get_row)."""
        if self.ismod:
            self.set_order(divmod(int(row), self.rpp))
        else:
            self.set_pos(row / self.rps)
    row = property(get_row, set_row)

    def get_volume(self):
        """Channel volume (BASS_ATTRIB_VOL)."""
        vol = c_float()
        if not BASS.ChannelGetAttribute(self.channel, 2, byref(vol)):  # BASS_ATTRIB_VOL
            BASS.throw("ChannelGetAttribute")
        return vol.value
    def set_volume(self, vol):
        if not BASS.ChannelSetAttribute(self.channel, 2, vol):  # BASS_ATTRIB_VOL
            BASS.throw("ChannelSetAttribute")
    volume = property(get_volume, set_volume)

    def get_level(self):
        """Peak of the two 16-bit level words scaled by 1/32768; 0.0 on error."""
        level = BASS.ChannelGetLevel(self.channel)
        if level == 0xFFFFFFFF:
            return 0.0
        return max(level >> 16, level & 0xFFFF) * (1.0 / 32768)
    level = property(get_level)

    def get_db(self):
        """Level in dB (20*log10), floored at -96.0 for near-silence."""
        # read the level once: a second read could return 0 and make log() fail
        level = self.get_level()
        if level < 2.5E-05:
            return -96.0
        return 8.685889638065037 * math.log(level)  # 20 / ln(10)
    db = property(get_db)

    def get_fft(self, size=128, db=False):
        """Return FFT magnitudes reduced to between *size* and 2*size bins.

        BASS delivers the smallest power-of-two bin count >= *size* (128 to
        8192); adjacent bins are averaged while at least 2*size remain.  On a
        BASS error, *size* zeros are returned.  With *db*, magnitudes are
        converted to dB with a floor of 1e-6.

        :raises ValueError: if *size* is below 1 or above 8192.
        """
        if size > 8192:
            raise ValueError("can't get FFT results with more than 8192 points")
        if size < 1:
            raise ValueError("FFT size must be at least 1 point")
        real_size = 128
        fmt_id = 0x80000000  # BASS_DATA_FFT256 -> 128 points
        while size > real_size:
            real_size *= 2
            fmt_id += 1  # each following BASS_DATA_FFTxxx doubles the bin count
        if BASS.fixed_point_fft:
            buf = (c_int32 * real_size)()
        else:
            buf = (c_float * real_size)()
        nbytes = BASS.ChannelGetData(self.channel, cast(buf, c_void_p), fmt_id)
        if nbytes < 0:
            return [0.0] * size
        assert nbytes == (real_size * 4)
        bins = list(buf)
        if BASS.fixed_point_fft:
            bins = [x * (1.0 / 16777216) for x in bins]  # 2**-24 fixed-point scale
        while len(bins) >= (size * 2):
            bins = [0.5 * (bins[i] + bins[i + 1]) for i in range(0, len(bins), 2)]
        if db:
            bins = [8.685889638065037 * math.log(max(x, 1.0E-6)) for x in bins]
        return bins

################################################################################

if __name__ == "__main__":
    # Command-line demo: play a file/URL (or a random track from a directory),
    # showing position, row, VU meter and FFT bars on stderr until Ctrl+C,
    # then print a histogram of the position deltas between polls (in ms)
    # and the BASS CPU load.
    import time, collections

    if len(sys.argv) < 2:
        sys.exit("usage: %s <file, URL or directory>" % sys.argv[0])
    uri = sys.argv[1]
    if os.path.isdir(uri):
        import random
        playable = ("mp3", "ogg", "mod", "xm", "it", "s3m")
        candidates = [f for f in os.listdir(uri)
                      if os.path.splitext(f)[-1].strip('.').lower() in playable]
        if not candidates:
            sys.exit("no playable files found in %s" % uri)
        uri = os.path.join(uri, random.choice(candidates))

    Init()
    track = Track(uri)
    print("now playing:", uri)
    print("length:", track.length)

    fft_levels = "#Oo." + 1000 * ' '
    jitter_stats = collections.defaultdict(int)
    last = 0
    track.play()
    try:
        while True:
            time.sleep(1.0 / 60)
            pos = track.pos
            jitter_stats[int(1000.0 * (pos - last) + 0.5)] += 1
            last = pos
            vu = min(24, max(0, int((track.db + 24.0) + 0.75)))
            # one character per bin, denser for louder; clamp so loud bins
            # (positive dB) never index from the end of the string
            fft = ''.join(fft_levels[max(0, int(-0.1 * x))] for x in track.get_fft(32, True))
            sys.stderr.write("%6.2fs | r%06.1f | %s%s | %s\r" % (pos, track.row, '#' * vu, '-' * (24 - vu), fft))
            sys.stderr.flush()
    except KeyboardInterrupt:
        print()
    cpu = BASS.GetCPU()
    track.stop()
    if jitter_stats:  # empty if interrupted before the first poll
        scale = 60.0 / max(jitter_stats.values())
        for delta_ms in range(min(jitter_stats), max(jitter_stats) + 1):
            freq = jitter_stats[delta_ms]
            print("%4d | %6d | %s" % (delta_ms, freq, "#" * int(freq * scale + 0.75)))
    print("CPU load: %.2f%%" % cpu)
    del track  # free the channel while BASS is still initialized
    Done()
