#pragma once
#include <windows.h>
#include <math.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <wchar.h>


void dprintf (wchar_t * format, ...)
//////////////////////////////////
{
  static wchar_t dbgbuf_w[2048];
	va_list arg;
	va_start(arg, format);
	vswprintf(dbgbuf_w, format, arg);
	va_end(arg);
  OutputDebugString(dbgbuf_w);
}


// The matrix code below was taken from an online source; the original
// reference has since been lost.

// Square float matrix with an in-place inverse based on LU (Crout)
// decomposition.
//
// Storage is one heap buffer of maxsize*maxsize floats in row-major order:
// element (row, column) lives at data[row*maxsize + column]. Only the
// top-left actualsize x actualsize sub-matrix is in use, so one object can be
// reused for smaller problems without reallocating. The object owns its
// buffer; copies are deep.
class matrix
{

  int     maxsize;     // allocated dimension: buffer holds maxsize*maxsize floats
  int     actualsize;  // dimension of the sub-matrix in use (0 <= actualsize <= maxsize)
  float*  data;        // owned row-major buffer, released in the destructor

  // Replaces the buffer with a new, uninitialised one of
  // newmaxsize*newmaxsize floats and records newmaxsize as maxsize.
  // Previous contents are discarded. The new buffer is obtained before the
  // old one is freed, so if new[] throws, this object is left unchanged.
  void allocate(int newmaxsize)
  /////////////////////////////
  {
    float* fresh = new float [newmaxsize*newmaxsize];
    delete[] data;
    data    = fresh;
    maxsize = newmaxsize;
  };

  // Shared constructor logic. A non-positive maxsize falls back to 5; an
  // actualsize outside 1..maxsize falls back to maxsize.
  void init(int newmaxsize, int newactualsize)
  ////////////////////////////////////////////
  {
    if (newmaxsize <= 0) newmaxsize = 5;
    if ((newactualsize > newmaxsize) || (newactualsize <= 0))
      newactualsize = newmaxsize;
    data = 0;                  // nothing owned yet; allocate() deletes 'data'
    allocate(newmaxsize);      // also sets maxsize
    actualsize = newactualsize;
  };

  // Default: same fallback as a non-positive size passed to the public
  // constructor (5x5, fully in use). Every member is initialised so the
  // destructor is safe.
  matrix  () { init(0, 0); };

  // newmaxsize x newmaxsize matrix, fully in use. Initialises *this
  // directly: the expression 'matrix(n, n);' would only create and destroy
  // a temporary, leaving this object uninitialised.
  explicit matrix  (int newmaxsize)
  /////////////////////////////////
  {
    init(newmaxsize, newmaxsize);
  };

public:

  // Creates a matrix with room for newmaxsize x newmaxsize values and
  // newactualsize x newactualsize in use. Invalid sizes fall back as
  // described for init(). Element values start uninitialised.
  matrix(int newmaxsize, int newactualsize)
  //////////////////////////////////////////
  {
    init(newmaxsize, newactualsize);
  };

  // Deep copy. Same maxsize and actualsize as 'source'; only the live
  // region is copied.
  matrix(const matrix& source)
  ////////////////////////////
  {
    data = 0;                  // nothing owned yet; allocate() deletes 'data'
    allocate(source.maxsize);
    actualsize = source.actualsize;
    for (int r = 0; r < actualsize; r++)
      for (int c = 0; c < actualsize; c++)
        data[r*maxsize + c] = source.data[r*source.maxsize + c];
  };

  // Deep-copy assignment (copy-and-swap). If the copy throws, *this is
  // left untouched.
  matrix& operator=(const matrix& source)
  ///////////////////////////////////////
  {
    if (this != &source)
    {
      matrix copy(source);
      float* olddata = data;
      maxsize    = copy.maxsize;
      actualsize = copy.actualsize;
      data       = copy.data;
      copy.data  = olddata;    // old buffer is released by copy's destructor
    }
    return *this;
  };

  ~matrix()
  //////////
  {
    delete[] data;
  };

  // Reports via dprintf how far the live sub-matrix is from the identity:
  // the diagonal entry with the largest |value - 1| and the off-diagonal
  // entry with the largest |value|. On ties, the first entry found in
  // row-major order is reported. Useful, for example, on the product of a
  // matrix and its computed inverse.
  void comparetoidentity() const
  //////////////////////////////
  {
    int   worstdiagonal     = 0;
    float maxunitydeviation = 0.0f;

    for (int i = 0; i < actualsize; i++)
    {
      float deviation = data[i*maxsize+i] - 1.0f;
      if (deviation < 0.0f) deviation = -deviation;
      if (deviation > maxunitydeviation)
      {
        maxunitydeviation = deviation;
        worstdiagonal = i;
      }
    }

    int   worstoffdiagonalrow    = 0;
    int   worstoffdiagonalcolumn = 0;
    float maxzerodeviation       = 0.0f;

    for (int i = 0; i < actualsize; i++)
    {
      for (int j = 0; j < actualsize; j++)
      {
        if (i == j) continue;  // only off-diagonal terms here
        float deviation = data[i*maxsize+j];
        if (deviation < 0.0f) deviation = -deviation;
        if (deviation > maxzerodeviation)
        {
          maxzerodeviation = deviation;
          worstoffdiagonalrow = i;
          worstoffdiagonalcolumn = j;
        }
      }
    }

    dprintf (L"Worst diagonal value deviation from unity: %f at row/column %d\n" ,maxunitydeviation, worstdiagonal);
    dprintf (L"Worst off-diagonal value deviation from zero: %f at %d, %d\n",  maxzerodeviation, worstoffdiagonalrow, worstoffdiagonalcolumn);
  }


  // Sets this matrix to left * right. The dimension is taken from
  // left.getactualsize(); right must have at least that many rows and
  // columns in use (not checked). The buffer grows, discarding old values,
  // if it is too small.
  // Either operand may be *this: the product is then formed in a temporary,
  // because writing in place would overwrite operand entries still needed.
  void settoproduct(const matrix& left, const matrix& right)
  //////////////////////////////////////////////////////////
  {
    const int n = left.getactualsize();

    if ((this == &left) || (this == &right))
    {
      matrix product(n, n);
      product.settoproduct(left, right);
      copymatrix(product);
      return;
    }

    if (maxsize < n) allocate(n);
    actualsize = n;

    for (int i = 0; i < n; i++)
    {
      for (int j = 0; j < n; j++)
      {
        float sum = 0.0f;
        for (int c = 0; c < n; c++)
          sum += left.data[i * left.maxsize + c] * right.data[c * right.maxsize + j];
        data[i*maxsize + j] = sum;
      }
    }
  }


  // Copies the live region of 'source' into this matrix. actualsize
  // becomes source.getactualsize(). If the buffer is too small, it is
  // reallocated first, discarding the old values.
  void copymatrix(const matrix& source)
  /////////////////////////////////////
  {
    const int n = source.getactualsize();
    if (maxsize < n) allocate(n);
    actualsize = n;
    for (int i = 0; i < n; i++)
      for (int j = 0; j < n; j++)
        data[i*maxsize + j] = source.data[i*source.maxsize + j];
  };


  // Changes the dimension in use. Growing beyond maxsize reallocates the
  // buffer and DISCARDS every existing value. Growing within maxsize
  // exposes whatever the buffer already holds, which may be uninitialised.
  // Negative sizes are ignored.
  void setactualsize(int newactualsize)
  //////////////////////////////////////
  {
    if (newactualsize > maxsize)
      allocate(newactualsize);

    if (newactualsize >= 0) actualsize = newactualsize;
  };

  int getactualsize() const
  /////////////////////////
  {
    return actualsize;
  };

  // No bounds checking: row and column must be within 0..maxsize-1.
  void getvalue(int row, int column, float& returnvalue) const
  ////////////////////////////////////////////////////////////
  {
    returnvalue = data[ row * maxsize + column ];
  };


  // No bounds checking: row and column must be within 0..maxsize-1.
  void setvalue(int row, int column, float newvalue)
  //////////////////////////////////////////////////
  {
    data[ row * maxsize + column ] = newvalue;
  };

  // Step 2 of invert(): replaces the lower triangle (diagonal included)
  // with the inverse of the L factor produced by Invert_FirstStep.
  // Column i is filled top to bottom. Each entry reads already-inverted
  // entries above it in column i and original L entries of its own row,
  // so the loop order matters.
  void InvertL ()
  ///////////////
  {
    for (int i = 0; i < actualsize; i++ )  // invert L
    {
      for ( int j = i; j < actualsize; j++ )
      {
        float x = 1.0;
        if ( i != j )
        {
          x = 0.0;
          for ( int k = i; k < j; k++ )
              x -= data[j*maxsize+k]*data[k*maxsize+i];
        }
        data[j*maxsize+i] = x / data[j*maxsize+j];
      }
    }
  }

  // Step 3 of invert(): replaces the strict upper triangle with the inverse
  // of the unit-diagonal U factor. The diagonal ones are implicit and the
  // diagonal storage (which holds L) is not touched.
  void InvertU()
  //////////////
  {
    for (int i = 0; i < actualsize; i++ )   // invert U
    {
      for ( int j = i; j < actualsize; j++ )
      {
        if ( i == j ) continue;
        float sum = 0.0;
        for ( int k = i; k < j; k++ )
            sum += data[k*maxsize+j]*( (i==k) ? 1.0f : data[i*maxsize+k] );
        data[i*maxsize+j] = -sum;
      }
    }
  }


  // Step 4 of invert(): forms inverse(A) = inverse(U) * inverse(L) in
  // place from the two inverted triangles. Correctness depends on the
  // i-outer / j-inner order; do not reorder these loops.
  void InvertFinal (void)
  //////////////////////
  {
    for (int i = 0; i < actualsize; i++ )   // final inversion
    {
      for ( int j = 0; j < actualsize; j++ )
      {
        float sum = 0.0;
        for ( int k = ((i>j)?i:j); k < actualsize; k++ )
        {
            sum += ((j==k)?1.0f:data[j*maxsize+k])*data[k*maxsize+i];
        }
        data[j*maxsize+i] = sum;
      }
    }
  }


  // Step 1 of invert(): in-place Crout decomposition A = L*U. L is lower
  // triangular including the diagonal; U is upper triangular with an
  // implicit unit diagonal and is stored in the strict upper triangle.
  // Pivots are divided by without any check (see invert()).
  void Invert_FirstStep (void)
  ////////////////////////////
  {
    int i;
    for (i=1; i < actualsize; i++) data[i] /= data[0]; // normalize row 0
    for (i=1; i < actualsize; i++)
    {
      for (int j=i; j < actualsize; j++)
      { // do a column of L
        float sum = 0.0;
        for (int k = 0; k < i; k++)
            sum += data[j*maxsize+k] * data[k*maxsize+i];
        data[j*maxsize+i] -= sum;
      }
      if (i == actualsize-1) continue;
      // 'j' is declared here: the loop variable of the loop above is out of
      // scope in standard C++ (only old MSVC for-scope rules leaked it).
      for (int j=i+1; j < actualsize; j++)
      {  // do a row of U
        float sum = 0.0;
        for (int k = 0; k < i; k++)
            sum += data[i*maxsize+k]*data[k*maxsize+j];

        data[i*maxsize+j] =
           (data[i*maxsize+j]-sum) / data[i*maxsize+i];
      }
    }
  }

  // Replaces the live actualsize x actualsize sub-matrix with its inverse,
  // in place.
  // No pivoting is performed. A zero pivot (for example data[0] == 0) causes
  // division by zero and inf/NaN results, even for an invertible matrix.
  // Near-zero pivots can amplify rounding error.
  void invert()
  /////////////
  {
    Invert_FirstStep();
    InvertL();
    InvertU();
    InvertFinal();
  };
};



// Test driver for the matrix class.
// Fills a 300x300 matrix with pseudo-random values in [-22, 78] and inverts
// it. It then multiplies the inverse by a saved copy of the original and
// reports, via dprintf, the GetTickCount() time for the inversion and the
// multiply and how far the product is from the identity.
int WINAPI WinMain(	HINSTANCE hInstance, HINSTANCE hPrevInstance, LPTSTR lpCmdLine, int nCmdShow)
/////////////////////////////////////////////////////////////////////////////////////////////////
{
  dprintf (L"hello world; this is a test of matrix inversion\n");

  matrix M1(300,300);  // for test we create & invert this matrix
  matrix M2(4,3);      // copy of the original M1 (copymatrix grows it)
  matrix M3(3,2);      // receives the product (settoproduct grows it)

  // Discards the first value of the unseeded, and therefore deterministic,
  // rand() sequence.
  // NOTE(review): presumably intentional; the run is reproducible either way.
  rand();

  const int n = M1.getactualsize();
  for (int i = 0; i < n; i++)
  {
    for (int j = 0; j < n; j++)
    {
      M1.setvalue(i,j,-22+(100.0f * rand())/RAND_MAX);
    }
  }

  M2.copymatrix(M1);

  // GetTickCount returns DWORD. Unsigned subtraction gives the correct
  // elapsed time even across the counter's wraparound.
  DWORD t1 = GetTickCount();
  M1.invert();
  t1 = GetTickCount()-t1;
  dprintf (L"DONE with matrix inversion!, %lu ticks\n", static_cast<unsigned long>(t1));

  DWORD t2 = GetTickCount();
  M3.settoproduct(M1,M2);
  t2 = GetTickCount()-t2;

  dprintf (L"Done with matrix multiply!, %lu ticks\n", static_cast<unsigned long>(t2));
  M3.comparetoidentity();

  return 0;
}
