/************************************************************************/
/*                                                                      */
/*   kernel.h                                                           */
/*                                                                      */
/*   User defined kernel function. Feel free to plug in your own.       */
/*                                                                      */
/*   Copyright: Thorsten Joachims                                       */
/*   Date: 16.12.97                                                     */
/*                                                                      */
/************************************************************************/

/* KERNEL_PARM is defined in svm_common.h. The field 'custom' is reserved for */
/* parameters of the user-defined kernel. You can also access and use */
/* the parameters of the other kernels. Just replace the line 
             return((double)(1.0)); 
   with your own kernel. */

  /* Example: The following computes the polynomial kernel. sprod_ss
              computes the inner product between two sparse vectors. 

      return((CFLOAT)pow(kernel_parm->coef_lin*sprod_ss(a->words,b->words)
             +kernel_parm->coef_const,(double)kernel_parm->poly_degree)); 
  */

/* If you are implementing a kernel that is not based on a
   feature/value representation, you might want to make use of the
   field "userdefined" in SVECTOR. By default, this field will contain
   whatever string you put behind a # sign in the example file. So, if
   a line in your training file looks like

   -1 1:3 5:6 #abcdefg

   then the SVECTOR field "words" will contain the vector 1:3 5:6, and
   "userdefined" will contain the string "abcdefg". */


/* --- jvdmatrix.h --------------------------------------------------------- */
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Index helpers for JVDMATRIX.  M is a JVDMATRIX pointer, i a zero-based row
 * index and j a zero-based column index.  Values are stored column by column.
 * Every macro evaluates its arguments more than once, so do not pass
 * expressions with side effects. */

/* Offset into M->val of element (i,j). */
#define sub2ind(M,i,j)    ((i)+(j)*(M)->rowc)
/* Row and column of the element stored at offset i. */
#define ind2row(M,i)      ((i)%(M)->rowc)
#define ind2col(M,i)      ((i)/(M)->rowc)
/* Nonzero iff (i,j) lies inside M.  Negative indices are rejected, and the
 * test order guarantees that the division is never reached when rowc is 0. */
#define subisvalid(M,i,j) ((i)>=0 && (j)>=0 && \
                           (i)<(M)->rowc && (j)<((M)->valc/(M)->rowc))

/* Dense matrix of doubles kept in one flat array. */
typedef struct {
	double *val;	/* valc values; element (i,j) is val[i + j*rowc] */
	int rowc;	/* number of rows */
	int valc;	/* total number of values (rows * columns) */
} JVDMATRIX;

/* Allocate a zero-filled rowc-by-colc matrix. */
JVDMATRIX* jvdmatrix_create(int rowc, int colc);
/* Free the matrix *M and its value array. */
void jvdmatrix_destroy(JVDMATRIX **M);
/* Write "<rows> <cols>" and then one value per line, in storage order. */
void jvdmatrix_fprint_vect( FILE *fd, JVDMATRIX *M);
/* Write "<rows> <cols>" and then one tab-separated line per row. */
void jvdmatrix_fprint_mat( FILE *fd, JVDMATRIX *M);
/* Read a matrix in the jvdmatrix_fprint_vect() format; the new matrix is
 * stored in *M and also returned. */
JVDMATRIX* jvdmatrix_fread( FILE *fd, JVDMATRIX **M );




/* --- kernel.h ------------------------------------------------------------ */
/* jvd: Note: K is never deleted because that would require modifying main() */
/* Kernel matrix shared by all calls; filled in by custom_kernel(). */
JVDMATRIX *K;

/*
 * Kernel that looks its value up in a matrix read from a file.
 *
 * While kernel_parm->custom is a non-empty string it is taken as the name of
 * the matrix file: the file is read into K and custom is then set to the
 * empty string, so later calls with the same kernel_parm skip the loading
 * step.  The value "empty" is treated as "no file given" and is an error.
 *
 * The row index i is derived from a and the column index j from b: a vector
 * whose first word has wnum <= 0 maps to index 0, any other vector maps to
 * the integer at the start of its "userdefined" string.
 *
 * Returns K(i,j).  Terminates the process with exit status
 *   1  if no matrix file was given or no matrix has been loaded,
 *   2  if the file cannot be opened or produced no matrix,
 *   3  if (i,j) lies outside the matrix.
 */
double custom_kernel(KERNEL_PARM *kernel_parm, SVECTOR *a, SVECTOR *b) {
	FILE *fdin;
	long int i,j;

	if( kernel_parm->custom[0]!='\0' ) {
		if( strcmp(kernel_parm->custom,"empty")==0 ) {
			fprintf(stderr,"ERROR: specify matrix file\n");
			exit(1);
		}
		/* jvd: read in matrix */
		fprintf(stdout,"Slurping \"%s\".\n",kernel_parm->custom);
		fdin = fopen(kernel_parm->custom,"r");
		if (fdin==NULL) {
			fprintf(stderr,"ERROR: could not open %s\n",kernel_parm->custom);
			exit(2);
		}
		/* jvd: known memory leak; jvdmatrix_destroy(K) is never called */
		jvdmatrix_fread( fdin, &K );
		fclose( fdin );
		if (K==NULL) {
			fprintf(stderr,"ERROR: could not read matrix from %s\n",kernel_parm->custom);
			exit(2);
		}
		/* an empty name marks the matrix as loaded */
		kernel_parm->custom[0]='\0';
	}

	/* custom was already empty on entry and nothing was ever loaded */
	if (K==NULL) {
		fprintf(stderr,"ERROR: no kernel matrix loaded\n");
		exit(1);
	}

	/* jvd: tests the first word with <= 0 (originally == 0) because that is
	 * what guy used.  strtol(...,NULL,10) parses exactly like atol() but,
	 * unlike atol(), has defined behaviour when the number is out of range. */
	i = (a->words->wnum<=(FNUM)0) ? 0 : strtol(a->userdefined,NULL,10);
	j = (b->words->wnum<=(FNUM)0) ? 0 : strtol(b->userdefined,NULL,10);

	/* negative indices are tested here as well so that the lookup below can
	 * never read in front of the value array */
	if ( i<0 || j<0 || !subisvalid(K,i,j) ) {
		fprintf(stderr,"ERROR: K(%ld,%ld) doesn't exist!\n",i,j);
		exit(3);
	}

	return K->val[sub2ind(K,i,j)];
}



/* --- jvdmatrix.c --------------------------------------------------------- */
/*
 * Allocate a rowc-by-colc matrix with every value set to 0.
 *
 * Never returns NULL.  Negative dimensions, an element count that does not
 * fit in an int, and allocation failures print a message to stderr and call
 * abort().  The checks are explicit tests rather than assert()s so that they
 * stay active when the file is compiled with NDEBUG.
 *
 * The caller owns the result and releases it with jvdmatrix_destroy().
 */
JVDMATRIX* jvdmatrix_create(int rowc, int colc) {
	JVDMATRIX *M;
	size_t count;

	if (rowc<0 || colc<0 || (colc!=0 && rowc>INT_MAX/colc)) {
		fprintf(stderr,"ERROR: invalid matrix size %d x %d\n",rowc,colc);
		abort();
	}
	count = (size_t)rowc*(size_t)colc;

	M = (JVDMATRIX*)calloc(1,sizeof(JVDMATRIX));
	if (M==NULL) {
		fprintf(stderr,"ERROR: out of memory allocating matrix\n");
		abort();
	}
	/* calloc(0,...) is allowed to return NULL, so an empty matrix still
	 * requests one element to keep the failure test unambiguous */
	M->val = (double*)calloc(count>0 ? count : 1,sizeof(double));
	if (M->val==NULL) {
		fprintf(stderr,"ERROR: out of memory allocating %d x %d matrix\n",rowc,colc);
		abort();
	}
	M->rowc = rowc;
	M->valc = (int)count;
	return M;
}


/*
 * Free the matrix *M together with its value array and set *M to NULL.
 *
 * Does nothing when M or *M is NULL.  Because the handle is reset to NULL
 * (and not to a poison value), destroying the same handle twice is harmless
 * and callers can test the handle against NULL afterwards.
 */
void jvdmatrix_destroy(JVDMATRIX **M) {
	if (M==NULL || *M==NULL)
		return;
	free((*M)->val);
	free(*M);
	*M = NULL;
}


/*
 * Write M to fd as a header line "<rows> <cols>" followed by every stored
 * value on a line of its own, in storage order.
 *
 * M must not be NULL.  A matrix without rows is written with a column count
 * of 0 instead of dividing by zero.
 */
void jvdmatrix_fprint_vect( FILE *fd, JVDMATRIX *M) {
	int i,colc;
	assert( M!=NULL );
	colc = (M->rowc>0) ? M->valc/M->rowc : 0;
	fprintf(fd,"%d %d\n",M->rowc,colc);
	for( i=0; i<M->valc; i++ )
		fprintf(fd,"%lf\n",M->val[i]);
}


/*
 * Write M to fd as a header line "<rows> <cols>" followed by one line per
 * row, the values of a row being separated by tabs.
 *
 * M must not be NULL.  A matrix without rows is written as the header only,
 * with a column count of 0 instead of dividing by zero.
 */
void jvdmatrix_fprint_mat( FILE *fd, JVDMATRIX *M) {
	int r,c,colc,rowc;
	assert( M!=NULL );
	rowc = M->rowc;
	colc = (rowc>0) ? M->valc/rowc : 0;
	fprintf(fd,"%d %d\n",rowc,colc);
	for (r=0;r<rowc;r++) {
		for (c=0;c<colc;c++) {
			/* tab between values, none before the first */
			if (c>0)
				fputc('\t',fd);
			fprintf(fd,"%lf",M->val[sub2ind(M,r,c)]);
		}
		fprintf(fd,"\n");
	}
}

/*
 * Read a matrix from fd.  The expected text is "<rows> <cols>" followed by
 * rows*cols values, which are stored in the order they appear.  This is the
 * format written by jvdmatrix_fprint_vect() and by the MATLAB commands
 * >> fprintf('%d %d\n',size(M)); fprintf('%f\n',reshape(M,numel(M),1));
 *
 * The new matrix is returned and, when M is not NULL, also stored in *M; the
 * caller releases it with jvdmatrix_destroy().
 *
 * Never returns NULL.  A header that cannot be parsed, dimensions that are
 * negative or too large for an int element count, and a file that ends
 * before all values were read print a message to stderr and terminate the
 * process with exit status 4.
 */
JVDMATRIX* jvdmatrix_fread( FILE *fd, JVDMATRIX **M ) {
	int i,rowc,colc;
	JVDMATRIX *mat;

	if (fscanf(fd,"%d %d\n",&rowc,&colc)!=2) {
		fprintf(stderr,"ERROR: matrix file has no valid \"rows cols\" header\n");
		exit(4);
	}
	/* validate before allocating so a corrupt header cannot request a
	 * negative or overflowing number of elements */
	if (rowc<0 || colc<0 || (colc!=0 && rowc>INT_MAX/colc)) {
		fprintf(stderr,"ERROR: invalid matrix size %d x %d\n",rowc,colc);
		exit(4);
	}

	mat = jvdmatrix_create( rowc, colc );
	for( i=0; i<mat->valc; i++ ) {
		if (fscanf(fd,"%lf\n",&(mat->val[i]))!=1) {
			fprintf(stderr,"ERROR: matrix file holds only %d of %d values\n",i,mat->valc);
			exit(4);
		}
	}

	if (M!=NULL)
		*M = mat;
	return mat;
}


