import sys
import os.path
import math
# from gdcmConfigDemo import *
from gdcmPython.core import *

class DicomEasyDict:
   """Convenience wrapper around a gdcm.File for a single DICOM file.

   On construction the file is loaded and every entry returned by
   gdcm's GetFirstEntry/GetNextEntry iteration is indexed by its
   (group, element) tag in self.fileDict for quick lookup.
   """

   def __init__(self, filename):
      self.file = gdcm.File.New()
      self.filename = None
      self.load( filename)
      self.readDictionary()

   def readDictionary(self):
      """Fill the quick access dictionary. (does not contain data keys)

      Maps (group, element) int tuples to the gdcm entry objects.
      """
      self.fileDict = {}
      entry = self.file.GetFirstEntry()
      while entry:
         key = entry.GetKey()
         self.fileDict[(int(key.GetGroup()), int(key.GetElement()))] = entry
         entry = self.file.GetNextEntry()

   def load(self, filename):
      """Set the file name on the gdcm.File and load it.

      Raises RuntimeError when gdcm reports the file as not readable.
      """
      self.filename = filename
      self.file.SetFileName( self.filename )
      self.file.Load()
      if not self.file.IsReadable():
         raise RuntimeError("Cannot read the file %s." % self.filename)

   # Tags whose string values hold the x, y and z gradient components
   # (odd group 0x0019, i.e. private tags).
   _GRADIENT_TAGS = ((0x0019, 0x10bb), (0x0019, 0x10bc), (0x0019, 0x10bd))

   def getGradientVector( self ):
      """Return the diffusion gradient direction as a list of 3 floats.

      Raises KeyError if any of the three gradient tags is missing.
      """
      entries = [ self.fileDict[tag] for tag in self._GRADIENT_TAGS ]
      return [ float(entry.GetString()) for entry in entries ]

   def getOrigin( self ):
      """Return [x, y, z] origin as reported by gdcm."""
      return [self.file.GetXOrigin(), self.file.GetYOrigin(), self.file.GetZOrigin() ]

   def getSize( self ):
      """Return the in-plane size [x, y] as reported by gdcm."""
      return [ self.file.GetXSize(), self.file.GetYSize () ]

   def getSpacing( self ):
      """Return the in-plane pixel spacing [x, y] as reported by gdcm."""
      return [ self.file.GetXSpacing(), self.file.GetYSpacing() ]

   def getImageOrientationPatient( self ):
      """Return all values of the image orientation entry as a list.

      Uses Image Orientation (Patient) (0020,0037); if that tag is absent,
      falls back to the retired Image Orientation tag (0020,0035).
      Raises KeyError if neither tag is present.
      """
      for tag in ((0x0020, 0x0037), (0x0020, 0x0035)):
         entry = self.fileDict.get(tag)
         if entry is not None:
            return [ entry.GetValue(x) for x in range( entry.GetValueCount() ) ]
      raise KeyError((0x0020, 0x0037))


def vecLength( f ):
   """Return the Euclidean length of the vector f (an iterable of numbers)."""
   squares = ( component * component for component in f )
   return math.sqrt( sum( squares, 0.0 ) )
   

import glob

def getDicomFileList( baseDir ):
   """Return the paths of all '*.dcm' entries directly inside baseDir.

   The result is in glob.glob order (not sorted); a missing directory
   simply yields an empty list.
   """
   pattern = '*.dcm'
   return glob.glob( os.path.join( baseDir, pattern ) )

files = []
for f in getDicomFileList( '/data/emc/unpack/DTI' ): # sys.argv[1] )[:3]:
   files += [ DicomEasyDict( f ) ]
   sys.stdout.write(".")
   sys.stdout.flush()

print "done loading"

#for f in files:
#   print f.getOrigin(), f.getImageOrientationPatient(), f.getSpacing(), f.getSize()

def crossProduct( a,b ):
   """Return the cross product a x b of two 3-component vectors as a list.

   Only the first three components of each argument are used.
   """
   ax, ay, az = a[0], a[1], a[2]
   bx, by, bz = b[0], b[1], b[2]
   return [ ay * bz - az * by,
            az * bx - ax * bz,
            ax * by - ay * bx ]

from math import *
import Numeric

# For every loaded file: compute its slice geometry, record its diffusion
# gradient, and dump its pixel data to a per-slice raw file.
# NOTE: rDirection, aDirection and hDirection are module-level names. After this
# loop they hold the values of the LAST file, and the b-vector / .transform code
# further down reads them from there (so it assumes all slices share one
# orientation).
# DTIGradients: one gradient tuple per file, duplicates included.
DTIGradients=[]

print "extracting raw data..."
for f in files:
   # print f.getOrigin(), f.getImageOrientationPatient()
   origin = f.getOrigin()
   spacing = f.getSpacing()
   size = f.getSize()
   trans = f.getImageOrientationPatient()
   # In-plane physical extent: pixel count times pixel spacing.
   sizeX = size[0] * spacing[0]
   sizeY = size[1] * spacing[1]

   # Split the orientation values into the first and second direction triples.
   rDirection = trans[0:3]
   aDirection = trans[3:6]
   # Slice normal as the cross product of the two in-plane directions.
   hDirection =crossProduct( rDirection, aDirection ) # another fiddlethingy
   
   # print "rdir:",rDirection, "adir:",aDirection, "hdir:",hDirection
   # Four in-plane corners: origin, origin+sizeX*dir1, origin+sizeX*dir1+sizeY*dir2,
   # origin+sizeY*dir2.
   # NOTE(review): the corner* attributes are not read anywhere else in this file.
   f.corner1 = origin
   f.corner2 = [org + trX*sizeX for org,trX in zip(origin, trans[0:3]) ]
   f.corner3 = [org + trX*sizeX + trY*sizeY for org,trX,trY in zip(origin, trans[0:3], trans[3:6] ) ]
   f.corner4 = [org + trY*sizeY for org,trY in zip(origin, trans[3:6]) ]
   # Position of the slice along its normal; later used as the sort key when
   # stacking slices into volumes.
   f.distance = Numeric.dot( hDirection, origin )
   print "file",f.filename,"direction",f.getGradientVector()
   DTIGradients += [ tuple(f.getGradientVector()) ]
   #print dir(f.file), f.file.GetDataEntry( 0x20,0x07 )
   fh = gdcm.FileHelper_New(  f.file  )
   # print dir(fh)
   # NOTE(review): the return value is discarded; presumably this makes gdcm
   # decode the pixel data before WriteRawData below - confirm.
   fh.GetImageData()
   # The raw file is named after the image number. Two files sharing an image
   # number would overwrite each other's raw file.
   f.rawfile = 'slice_%s.raw' % f.file.GetImageNumber()
   fh.WriteRawData( f.rawfile )
   #print f.file.GetImageNumber()

DTIGradients = set(DTIGradients)
print "Set of gradient directions:", DTIGradients
if len(DTIGradients)>1:
   print "DTI data assumed. directions = 1 +",len(DTIGradients)-1
   print "sorting slices by diffusion gradient..."
   GradSlices = {}
   for grad in DTIGradients:
      GradSlices[grad] = []
      
   for f in files:
      GradSlices[ tuple(f.getGradientVector()) ] += [f]

   i=0
   #zeroVector = None
   for grad in GradSlices:
      i+=1
      print "Direction #%d [%+0.3f  %+0.3f  %+0.3f] has %d slices" % (i , grad[0],grad[1],grad[2], len(GradSlices[grad]))
      #if vecLength( grad ) < 0.01:
      #   print "removing the zerovector (b0)"
      #   zeroVector = grad
         
   #if zeroVector is not None:
   #   zeroSlice = GradSlices[zeroVector]
   #   # del GradSlices[zeroVector]

   slicecounts = ( [len(GradSlices[x]) for x in GradSlices] )
   smin = min(slicecounts)
   smax = max(slicecounts)
   if smin != smax:
      print "This dataset does not seem to be complete. Some of the diffusion directions have",smin,"slices while others have",smax
      raise RuntimeError,"Dataset is not complete"

   print "This dataset seems to be fine. trying to build volumes"


   nrrdNames = {}
   i=0
   for grad in GradSlices:
      i+=1
      rawvolume = ""
      zres = 0
      
      GradSlices[grad].sort( key=lambda x:x.distance ) # sort by distance
      
      for slice in GradSlices[grad]:
         rawvolume += file( slice.rawfile, 'rb' ).read()
         xres,yres = slice.getSize()
         zres += 1

      print "extracted volume of size %dx%dx%d" % (xres,yres,zres)
      if xres*yres*zres*2 != len(rawvolume):
         print "volume size should be %d while it really is %d" % (xres*yres*zres*2 , len(rawvolume) )
         raise RuntimeError, "Volume does not seem to be the right size"

      dtiname = 'dti_%02d.raw' % i
      dtinrrdname = 'dti_%02d.nrrd' % i
      vol = file( dtiname,'wb' )
      vol.write( rawvolume )
      vol.close()

      zspace = GradSlices[grad][1].distance-GradSlices[grad][0].distance
      xspace,yspace = GradSlices[grad][0].getSpacing()

      # call teem here to build the nrrd files.
      execstring = "unu make -i '%s' -s %d %d %d -sp %f %f %f -t ushort -o '%s'" % (dtiname, xres, yres, zres, xspace, yspace, zspace, dtinrrdname)
      print "Calling external command: ",execstring
      status = os.system( execstring )
      if status != 0:
         print "The external invocation of TEEM's utility UNU has failed with error code",status
         raise RuntimeError, "TEEM/UNU Failure"
      nrrdNames[grad] = dtinrrdname

   print "Building b-matrix"
   bvecsname = 'bvectors.txt'
   bmatnrrdname = 'bmatrix.nrrd'
   bvecs = file(bvecsname ,'w')
   first = True
   for grad in GradSlices:
      #       lgrad = [
      #          rDirection[0] * grad[0] + grad[1] * aDirection[0] + grad[2] * hDirection[0],
      #          rDirection[1] * grad[0] + grad[1] * aDirection[1] + grad[2] * hDirection[1],
      #          rDirection[2] * grad[0] + grad[1] * aDirection[2] + grad[2] * hDirection[2] ]
      #if grad == zeroVector:
      #   print "Skipping B0 vector in gradient matrix generation"
      #   continue
      

      # this should be the correct one, i suppose
      lgrad = [
          (rDirection[0] * grad[0] + grad[1] * rDirection[1] + grad[2] * rDirection[2]),
          (aDirection[0] * grad[0] + grad[1] * aDirection[1] + grad[2] * aDirection[2]),
          (hDirection[0] * grad[0] + grad[1] * hDirection[1] + grad[2] * hDirection[2]) ] 

      # this looks like it kinda works
      lgrad = [
          (rDirection[0] * grad[0] + grad[1] * rDirection[1] + grad[2] * rDirection[2]),
          (aDirection[0] * grad[0] + grad[1] * aDirection[1] + grad[2] * aDirection[2]),
          -(hDirection[0] * grad[0] + grad[1] * hDirection[1] + grad[2] * hDirection[2]) ] 

      # lgrad = [ grad[0], grad[1], -grad[2] ]
      # lgrad = [ grad[0], grad[1], -grad[2] ]
      print "transformed",grad,"into",lgrad
      print >>bvecs, "%f %f %f" % (lgrad[0],lgrad[1],lgrad[2])
      
   bvecs.close()

   print "Executing tend to convert b vectors to b matrix"
   execstring = "tend bmat -i '%s' -o '%s'" % ( bvecsname, bmatnrrdname )
   status = os.system( execstring )
   if status != 0:
      print "The invocation of TEND has failed with error code",status
      raise RuntimeError, "TEEM/TEND failure in building b matrix"

   distortionCorrectionEnabled = False
   
   if distortionCorrectionEnabled:
      print "Performing distortion correction"
      execstring = "tend epireg -i "
      for nrrd in nrrdNames:
         execstring += "'"+nrrdNames[nrrd]+"' "

      nrrdNames = [ "ep_"+x for x in nrrdNames ]
      
      execstring += " -g bvectors.txt -o ep_dti_"
      print "Executing: ",execstring
      status = os.system( execstring )
      if status != 0:
         print "The invocation of TEND has failed with error code",status
         raise RuntimeError, "TEEM/TEND failure in distortion correction of the tensors"
   
   print "All directions are processed. Proceeding to run tensor estimation code"

   execstring = "tend estim -i "
   for nrrd in nrrdNames:
      execstring += "'"+nrrdNames[nrrd]+"' "

   execstring += " -B '%s' -knownB0 false -o tensors.nrrd" % bmatnrrdname

   print "Executing: ",execstring
   status = os.system( execstring )
   if status != 0:
      print "The invocation of TEND has failed with error code",status
      raise RuntimeError, "TEEM/TEND failure in estimating tensors"

   print "Writing .transform file"
   f = file("tensors.nrrd.transform", "w")
   baseSlice = GradSlices.values()[0][0]
   origin =  baseSlice.getOrigin()
   print >>f, "pos %f %f %f" % tuple(origin)
   print >>f, "uvec %f %f %f" % tuple(rDirection)
   print >>f, "vvec %f %f %f" % tuple(aDirection)
   print >>f, "wec %f %f %f" % tuple(hDirection)

   f.close()
   
      #unu make -i dtiname -s xres yres zres
# files.sort( key=lambda x:x.distance )


