#!/usr/bin/env python3
"""
GNU Rocket client with a Pythonic API.

Specify all track names on initialization:
    sync = Device("sync", ["track1", "track2", "trackGroup:one", "trackGroup:two"])
Just before your main loop, do:
    sync.start()
And inside your main loop, do:
    sync.update()
You can now access the current values of your tracks using the properties:
    sync.track1.value
    sync.track2.value
    sync.trackGroup.one.value
    sync.trackGroup.two.value
(Track names containing '.' or ':' separators are recognized as groups and
presented as a proper hierarchy inside the sync object; '-' characters in
track names become '_' in the attribute names.)

To control a music player, make sure it's represented by a class that offers
the following methods:
    get_row() -> float   [determine current position in rows]
    set_row(float)       [seek to specific position]
    play()               [start playback]
    pause()              [pause playback]
    is_playing() -> bool [false if paused, true if playing]
Then pass an instance of this class to the Device constructor and you're set:
    music = MyMusicPlayer(...)
    sync = Device("sync", ["track1", ...], controller=music)

By default, the device starts up as exactly this -- a client. If it fails
to reach a synctracker server (i.e. the editor) on startup, it automatically
falls back to player mode. If the argument client=False is passed to the
Device constructor, it starts up in player mode right away.

The file format used for export and player mode is a little different from
the original Rocket: it reads the old "dozens of small .track files strewn
all over the current directory" format just fine, but the native format is a
single file called <devicename>.tracks (unless overridden with the filename
argument), which is a ZIP file that contains all the individual track files
in compressed form.
In addition to that, when exporting tracks, it also writes a .rocket XML file
of the kind used by the editors -- basically, you no longer need to save in
the editors: Ctrl+E is enough, as it creates files for both scenarios (player
and editor). If the XML file already exists, its special per-track settings
(like the track's color in emoon's OpenGL-based editor) are preserved.
"""
# module metadata
__author__ = "Martin J. Fiedler / KeyJ^TRBL"
__version__ = "1.0"
import sys, math, time, socket, select, struct, zipfile, re

# protocol version: set to True if you want to use the old non-Qt Win32 editor
# (with the old protocol, each track registration message sent on connect also
# carries the track's numeric index before the name length; see
# Device._reconnect)
USE_OLD_PROTOCOL = False

################################################################################

# Blend curves indexed by a key's interpolation type. Each one maps the
# normalized position t between two keys (0 at the earlier key, 1 at the later
# one) to the blend factor applied to the value difference. Track.get() only
# consults this table for types above 1; types 0 and 1 are handled inline.
Interpolators = dict(enumerate((
    lambda t: 0.0,                          # 0: step (hold)
    lambda t: t,                            # 1: linear
    lambda t: t * t * (3.0 - 2.0 * t),      # 2: smooth (smoothstep)
    lambda t: t * t,                        # 3: ramp (quadratic)
)))

class Track(object):
    """
    A single sync track.

    Keys are stored in self.data as (row, value, interpolation) tuples sorted
    by row. self.value holds the most recently computed value and self.row
    the row it was computed for (None when the cache has been invalidated).
    """

    # record layout for tobytes()/frombytes(): little-endian unsigned 32-bit
    # row, 32-bit float value, unsigned byte interpolation type, no padding
    _RECORD = struct.Struct('<IfB')

    def __init__(self, name):
        self.name = name
        self.value = 0.0
        self.clear()

    def clear(self):
        """Remove all keys and invalidate the cached value."""
        # None rather than a row number means "nothing cached": every numeric
        # row (including -1) must trigger a real evaluation in get()
        self.row = None
        self.data = []

    def find(self, row):
        """
        Binary-search the keys for `row`.

        Returns the index of the key at exactly `row` if there is one,
        otherwise the index of the last key before `row`, or -1 if `row` lies
        before the first key (or the track has no keys).
        """
        lo = 0
        hi = len(self.data)
        while lo < hi:
            mid = (lo + hi) >> 1
            key_row = self.data[mid][0]
            if key_row < row:
                lo = mid + 1
            elif key_row > row:
                hi = mid
            else:
                return mid
        return lo - 1

    def set(self, row, value, inter=0):
        """Insert a key at `row`, replacing any existing key at that row."""
        i = self.find(row)
        key = (row, value, inter)
        if i >= 0 and self.data[i][0] == row:
            self.data[i] = key
        else:
            # i is -1 or the last key before row: insert right after it
            self.data.insert(i + 1, key)
        self.row = None

    def remove(self, row):
        """Delete the key at `row`; does nothing if there is no such key."""
        i = self.find(row)
        if i < 0 or self.data[i][0] != row:
            return
        del self.data[i]
        self.row = None

    def get(self, row):
        """
        Evaluate the track at `row` (which may be fractional) and return it.

        Before the first key the first key's value is returned, after the last
        key the last key's value, and 0.0 if the track has no keys. Between two
        keys the earlier key's interpolation type decides: 0 holds its value,
        1 is linear, higher types go through Interpolators (unknown types fall
        back to linear). The result is cached in self.value / self.row.
        """
        if row == self.row:
            return self.value
        data = self.data
        # floor instead of int(), which truncates toward zero: rows in (-1, 0)
        # lie before a key at row 0 and must not extrapolate from it
        i = self.find(math.floor(row))
        if i < 0:
            v = data[0][1] if data else 0.0
        elif i + 1 >= len(data):
            v = data[-1][1]
        else:
            ra, va, ia = data[i]
            if not ia:
                v = va
            else:
                rb, vb, _ = data[i + 1]
                t = (row - ra) / float(rb - ra)
                if ia > 1:
                    curve = Interpolators.get(ia)
                    if curve is not None:
                        t = curve(t)
                v = va + t * (vb - va)
        self.row = row
        self.value = v
        return v

    def tobytes(self):
        """Serialize all keys into the binary .track record format."""
        return b''.join(self._RECORD.pack(*item) for item in self.data)

    def frombytes(self, s):
        """
        Replace all keys with those decoded from `s` (see tobytes()).

        Raises struct.error if len(s) is not a multiple of the record size.
        """
        self.data = list(self._RECORD.iter_unpack(s))
        self.row = None

    def __float__(self):
        return self.value

    def __int__(self):
        return int(self.value)

################################################################################

class FakeController(object):
    """
    Controller that advances the row position with elapsed real time (instead
    of controlling some kind of music playback).

    Elapsed time is measured with time.monotonic(), so adjustments of the
    system wall clock (NTP corrections, manual changes) cannot make the
    position jump.
    """

    def __init__(self, rps=10.0):
        self.rps = rps          # rows per second
        self.playing = False
        self.t0 = 0             # clock reading at which r0 was captured
        self.r0 = 0             # row position at time t0

    def get_row(self):
        """Return the current (fractional) row position."""
        if not self.playing:
            return self.r0
        return (time.monotonic() - self.t0) * self.rps + self.r0

    def set_row(self, row):
        """Seek to `row`."""
        self.t0 = time.monotonic()
        self.r0 = row

    def play(self):
        """Start (or continue) advancing from the current position."""
        self.r0 = self.get_row()
        self.t0 = time.monotonic()
        self.playing = True

    def pause(self):
        """Freeze the position at its current value."""
        self.r0 = self.get_row()
        self.playing = False

    def is_playing(self):
        """Return True while playing, False while paused."""
        return self.playing

class TrackGroup:
    """
    Empty namespace object: Device attaches tracks and nested groups to it as
    attributes to mirror ':'/'.'-separated track names.
    """

class Device(object):
    """
    Class representing a synctracker device.

    In client mode the device keeps a TCP connection to an editor, applies the
    key edits the editor sends and reports the playback row back to it; in
    player mode it only evaluates track data loaded from disk.
    """

    def __init__(self, name, tracks, client=True, controller=None, autoupdate=True, filename=None, client_host="localhost", client_port=1338):
        """
        Constructor.
        - name: the name of the tracker instance (used as a prefix for file
                names)
        - tracks: a list of strings containing track names; every track is
                  also exposed as an attribute of the device ('.' and ':'
                  create nested groups, '-' is replaced by '_')
        - client: if False, the device will act as a player: it loads track
                  data from disk and plays that back non-interactively;
                  if True, the device will act as a client: it will connect
                  to an editor on startup; if the editor connection fails,
                  it automatically falls back to player mode
        - controller: reference to a controller object to use in client mode;
                      this object must provide the following methods:
                        get_row() -> float   [determine current position in rows]
                        set_row(float)       [seek to specific position]
                        play()               [start playback]
                        pause()              [pause playback]
                        is_playing() -> bool [false if paused, true if playing]
                      if not specified, a FakeController with 10 rows per second
                      is used that runs against the wall clock
        - autoupdate: True if update() shall also update the values of all track
                      variables so that they are immediately available with
                      their .value properties;
                      False if .get() is going to be called explicitly by the
                      application
        - filename: the filename of the ZIP file containing the tracks
                    (default: inferred from the instance name)
        - client_host: host to connect to in client mode (default: localhost)
        - client_port: TCP port to connect to in client mode (default: 1338)
        """
        self._name = name
        self._tracks = []
        for track_name in tracks:
            track = Track(track_name)
            self._tracks.append(track)
            # NOTE(review): a name component that equals an existing Device
            # attribute (e.g. 'update', or 'row', which is reassigned below)
            # shadows or loses that attribute
            path = track_name.replace(':', '.').replace('-', '_').split('.')
            root = self
            for sub in path[:-1]:
                try:
                    root = getattr(root, sub)
                except AttributeError:
                    group = TrackGroup()
                    setattr(root, sub, group)
                    root = group
            setattr(root, path[-1], track)
        self._client = (client_host, client_port) if client else None
        self._autoupdate = autoupdate
        # explicit None check: a controller object that happens to be falsy
        # must still be used
        self._controller = controller if controller is not None else FakeController()
        self._filename = filename or (self._name + ".tracks")
        self._sock = None
        self._last_row = -1     # last row reported to the editor
        self.row = 0
        try:
            self._reconnect()
        except (EnvironmentError, EOFError) as e:
            # EOFError: the editor closed the connection during the handshake
            print("rocket.Device: failed to connect to editor -", e, "- acting as a player now", file=sys.stderr)
            self._client = False
        if not self._client:
            self._load()

    def _send(self, data):
        """Send `data` to the editor; closes the connection on failure."""
        if not self._sock:
            raise IOError("no connection")
        try:
            self._sock.sendall(data)
        except socket.error as e:
            self._close()
            raise IOError(e) from e

    def _recv(self, size):
        """
        Read exactly `size` bytes from the editor.

        Raises IOError on socket errors (including timeouts) and EOFError if
        the editor closed the connection; the connection is closed either way.
        """
        if not self._sock:
            raise IOError("no connection")
        buf = b""
        while len(buf) < size:
            try:
                block = self._sock.recv(size - len(buf))
            except socket.error as e:
                self._close()
                raise IOError(e) from e
            if not block:
                self._close()
                raise EOFError("connection closed")
            buf += block
        return buf

    def _close(self):
        """Close the editor connection, if any (safe to call repeatedly)."""
        sock, self._sock = self._sock, None
        if sock:
            sock.close()

    def _reconnect(self):
        """
        (Re)connect to the editor, perform the handshake and register all
        tracks; the local key data of every track is cleared in the process.
        Does nothing in player mode.

        Raises IOError (or EOFError) if the connection cannot be established.
        """
        if not self._client:
            return
        self._close()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(self._client)
        except socket.error as e:
            sock.close()    # don't leak the descriptor of a failed attempt
            raise IOError(e) from e
        sock.settimeout(1)
        self._sock = sock
        try:
            self._send(b"hello, synctracker!")
            if self._recv(12) != b"hello, demo!":
                raise IOError("invalid client response")
            for index, track in enumerate(self._tracks):
                # the length prefix counts bytes, not characters
                encoded = track.name.encode('utf-8')
                if USE_OLD_PROTOCOL:
                    header = struct.pack('>II', index, len(encoded))
                else:
                    header = struct.pack('>I', len(encoded))
                self._send(b"\x02" + header + encoded)
                track.clear()
        except (IOError, EOFError):
            self._close()
            raise
        self._last_row = -1

    def _track_filename(self, track):
        """Return the per-track file name used on disk and inside the ZIP."""
        return self._name + "_" + track.name + ".track"

    def _load(self):
        """
        Load all tracks for player mode.

        A loose per-track file in the current directory takes precedence over
        the entry of the same name in the ZIP archive self._filename; tracks
        found in neither place are cleared.
        """
        try:
            arch = zipfile.ZipFile(self._filename, 'r')
        except (IOError, zipfile.BadZipFile):
            arch = None
        try:
            for track in self._tracks:
                fn = self._track_filename(track)
                try:
                    with open(fn, 'rb') as f:
                        raw = f.read()
                except IOError:
                    raw = None
                if raw is None and arch is not None:
                    try:
                        raw = arch.read(fn)
                    except KeyError:
                        pass
                if raw is None:
                    track.clear()
                else:
                    track.frombytes(raw)
        finally:
            if arch is not None:
                arch.close()

    def _save(self):
        """Write all tracks into the (deflate-compressed) ZIP archive."""
        with zipfile.ZipFile(self._filename, 'w', zipfile.ZIP_DEFLATED) as arch:
            for track in self._tracks:
                arch.writestr(self._track_filename(track), track.tobytes())

    def _savexml(self):
        """
        Write <name>.rocket, an editor XML file containing all keys.

        The <tracks> tag and matching <track> tags of an existing file are
        reused (so editor-specific attributes survive); other tracks get a
        plain <track name="..."> tag.
        """
        xml_name = self._name + ".rocket"
        nrows = max((t.data[-1][0] if t.data else 0 for t in self._tracks), default=0)
        # next multiple of 8 above the last key row, but at least 128
        tracks_tag = '<tracks rows="%d">' % max((nrows & ~7) + 8, 128)
        track_tags = {t.name: '<track name="%s">' % t.name for t in self._tracks}
        try:
            with open(xml_name, encoding='utf-8') as f:
                old_xml = f.read()
        except (EnvironmentError, UnicodeDecodeError):
            old_xml = ""
        for m in re.finditer(r'<(tracks?)\b([^>]*)>', old_xml, flags=re.DOTALL | re.I):
            tag = m.group(0).replace('\n', ' ')
            # keys and a closing tag get appended below, so a self-closing
            # tag must become a plain opening tag to keep the XML valid
            if tag.endswith('/>'):
                tag = tag[:-2].rstrip() + '>'
            if m.group(1).lower() == "tracks":
                tracks_tag = tag
            else:
                t = re.search(r'name="([^"]+)"', m.group(2), flags=re.I)
                if t:
                    track_tags[t.group(1)] = tag
        lines = ['<?xml version="1.0" encoding="utf-8"?>', tracks_tag]
        for track in self._tracks:
            lines.append('\t' + track_tags[track.name])
            for key in track.data:
                lines.append('\t\t<key row="%d" value="%.6f" interpolation="%d" />' % tuple(key))
            lines.append('\t</track>')
        lines.append('</tracks>')
        with open(xml_name, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines) + '\n')

    def start(self):
        """
        Initially start playback; call this directly before your main loop.
        """
        self._controller.play()

    def _handle_command(self, irow):
        """
        Read one command from the editor and execute it.

        `irow` is the integer row update() is working with in this frame.
        Raises IOError/EOFError on connection problems.
        """
        cmd = self._recv(1)[0]
        if cmd == 0:  # SET_KEY
            tid, key_row, value, inter = struct.unpack('>IIfB', self._recv(13))
            if tid < len(self._tracks):
                self._tracks[tid].set(key_row, value, inter)
        elif cmd == 1:  # DELETE_KEY
            tid, key_row = struct.unpack('>II', self._recv(8))
            if tid < len(self._tracks):
                self._tracks[tid].remove(key_row)
        elif cmd == 3:  # SET_ROW
            new_row = struct.unpack('>I', self._recv(4))[0]
            self._controller.set_row(new_row)
            # irow is stale now; marking it as already reported keeps update()
            # from sending the old row back and undoing the editor's seek
            self._last_row = irow
        elif cmd == 4:  # PAUSE
            if self._recv(1)[0]:
                self._controller.pause()
            else:
                self._controller.play()
        elif cmd == 5:  # SAVE_TRACKS
            try:
                self._save()
                self._savexml()
            except Exception as e:
                print("rocket.Device: failed to save tracks -", e, file=sys.stderr)
                # disk errors must not tear down the (healthy) editor link;
                # anything else is a bug and is propagated
                if not isinstance(e, EnvironmentError):
                    raise
        else:
            # unknown command: the stream can't be parsed any further
            self._close()

    def update(self, row=None):
        """
        Update the internal state; this *must* be called in regular intervals
        (i.e. every frame)!
        - row: the current position, in rows (floating-point values allowed);
               if omitted, this will be taken from the controller
        The last value of row can be accessed using the 'row' property of the
        Device object.
        In client mode, all pending editor commands are processed and, while
        playing, the current row is reported to the editor.
        Returns the row that was used.
        """
        if row is None:
            row = self._controller.get_row()
        self.row = row
        if self._autoupdate:
            for track in self._tracks:
                track.get(row)
        if not self._client:
            return row
        irow = max(0, int(row))

        # process every command the editor has queued up, without blocking
        while True:
            if not self._sock:
                try:
                    self._reconnect()
                except (EnvironmentError, EOFError) as e:
                    print("rocket.Device: reconnect failed -", e, file=sys.stderr)
                    return row
            try:
                readable, _, errored = select.select([self._sock], [], [self._sock], 0)
            except (EnvironmentError, ValueError):
                # select() raises ValueError for a closed socket (fileno -1)
                self._close()
                continue
            if errored:
                self._close()
                continue
            if not readable:
                break
            try:
                self._handle_command(irow)
            except (IOError, EOFError):
                self._close()

        if self._sock and (irow != self._last_row) and self._controller.is_playing():
            try:
                self._send(b"\x03" + struct.pack('>I', irow))
                self._last_row = irow
            except (IOError, EOFError):
                self._close()
        return row

    def is_client(self):
        "return True if running in client mode"
        return bool(self._client)

    def is_player(self):
        "return True if running in player mode"
        return not self._client

################################################################################

if __name__ == "__main__":
    run_track_stress_test = True  # set to False to go straight to the demo
    if run_track_stress_test:
        # Track editing stress test: apply the same random set/remove
        # operations to a Track and to a plain dict, then check they agree.
        import random
        t = Track("unit-test")
        ref = {}
        for _ in range(10000):
            row = random.randrange(100)
            if random.randrange(2):
                value = random.random()
                ref[row] = value
                t.set(row, value)
            else:
                ref.pop(row, None)
                t.remove(row)
        # zip() stops at the shorter sequence, so a missing or surplus key
        # would go unnoticed without this explicit length check
        assert len(t.data) == len(ref), (len(t.data), len(ref))
        for item, ref_row in zip(t.data, sorted(ref)):
            assert item[0] == ref_row
            assert item[1] == ref[ref_row]

    # Interactive demo: connect to an editor (or fall back to player mode)
    # and print the current track values until interrupted with Ctrl+C.
    sync = Device("rocket_test", ["test", "group:a", "group:b"])
    sync.start()
    try:
        while True:
            sync.update()
            sys.stderr.write("  row = %.2f | test = %.2f | g:a = %.2f | g:b = %.2f      \r" % (sync.row, sync.test.value, sync.group.a.value, sync.group.b.value))
            sys.stderr.flush()
            time.sleep(1.0 / 60)
    except KeyboardInterrupt:
        print()
