library pgraph optmum;
graphset;
optset;


/* Calculate log to the base 2 of x. */
proc log2(x);
  /*
  Base-2 logarithm of x, applied element by element, via the
  change-of-base rule log2(x) = log(x)/log(2).
  */
  local denom;
  denom=log(2);
  retp(log(x)./denom);
endp;

/* Routine for the matrix exponential. */
proc mexp(a);
  /*
   *************************************************************************
   **   Procedure to calculate matrix exponential for a square matrix A   **
   **---------------------------------------------------------------------**
   **   Uses method of scaling and squaring from Golub & Van Loan.        **
   **   Special Horner technique (section 11.2) is used for an 8th        **
   **   order approximation. See page 399.                                **
   *************************************************************************
   Input : a, an n x n matrix.
   Output: an n x n approximation to exp(a), formed as the (8,8) Pade
           approximant of exp(a/2^scale) squared scale times.
  */
  local a2,a4,c0,c1,c2,c3,c4,c5,c6,c7,c8,scale,ii,k,f,u,v,dqq,nqq;
  /*
  (8,8) Pade coefficients c_k = (16-k)! 8! / (16! k! (8-k)!), written as
  exact fractions.  Decimal literals rounded to 6-9 significant digits
  perturb the approximant at about that level instead of double precision.
  */
  c0=1;
  c1=1/2;
  c2=7/60;
  c3=1/60;
  c4=1/624;
  c5=1/9360;
  c6=1/205920;
  c7=1/7207200;
  c8=1/518918400;
  /*
  Scale a by 2^-scale so that its infinity norm (largest absolute row sum,
  the norm used by Golub & Van Loan) is below 1.  The 0.01 keeps log2
  finite for a zero matrix.
  */
  scale=1+int(log2(0.01+maxc(sumc(abs(a')))));
  if scale<0;
    scale=0;
  endif;
  a=a/2^scale;
  /* Even (u) and odd (v) parts of the Pade polynomials, Horner-style in a^4. */
  a2=a*a;
  a4=a2*a2;
  ii=eye(rows(a));
  u=c0*ii+c2*a2+(c4*ii+c6*a2+c8*a4)*a4;
  v=c1*ii+c3*a2+(c5*ii+c7*a2)*a4;
  nqq=u+a*v;
  dqq=u-a*v;
  /* Solve dqq*f = nqq directly rather than forming inv(dqq). */
  f=nqq/dqq;
  /* Undo the scaling: exp(a) = exp(a/2^scale)^(2^scale). */
  k=0;
  do while (k<scale);
    f=f*f;
    k=k+1;
  endo;
  retp(f);
endp;

/*
Enter covariates: four rows per subject, one per period.
Columns as entered: treatment, baseline (base), age, then two further
columns (column 4 is 1 only on each subject's fourth row; column 5 runs
-0.3, -0.1, 0.1, 0.3 within a subject).  Below, the baseline is divided
by 4, age is log-transformed, and columns 4 and 5 are dropped.
*/
x={0 11 31 0 -0.3,
   0 11 31 0 -0.1,
   0 11 31 0  0.1,
   0 11 31 1  0.3,
   0 11 30 0 -0.3,
   0 11 30 0 -0.1,
   0 11 30 0  0.1,
   0 11 30 1  0.3,
   0  6 25 0 -0.3,
   0  6 25 0 -0.1,
   0  6 25 0  0.1,
   0  6 25 1  0.3,
   0  8 36 0 -0.3,
   0  8 36 0 -0.1,
   0  8 36 0  0.1,
   0  8 36 1  0.3,
   0 66 22 0 -0.3,
   0 66 22 0 -0.1,
   0 66 22 0  0.1,
   0 66 22 1  0.3,
   0 27 29 0 -0.3,
   0 27 29 0 -0.1,
   0 27 29 0  0.1,
   0 27 29 1  0.3,
   0 12 31 0 -0.3,
   0 12 31 0 -0.1,
   0 12 31 0  0.1,
   0 12 31 1  0.3,
   0 52 42 0 -0.3,
   0 52 42 0 -0.1,
   0 52 42 0  0.1,
   0 52 42 1  0.3,
   0 23 37 0 -0.3,
   0 23 37 0 -0.1,
   0 23 37 0  0.1,
   0 23 37 1  0.3,
   0 10 28 0 -0.3,
   0 10 28 0 -0.1,
   0 10 28 0  0.1,
   0 10 28 1  0.3,
   0 52 36 0 -0.3,
   0 52 36 0 -0.1,
   0 52 36 0  0.1,
   0 52 36 1  0.3,
   0 33 24 0 -0.3,
   0 33 24 0 -0.1,
   0 33 24 0  0.1,
   0 33 24 1  0.3,
   0 18 23 0 -0.3,
   0 18 23 0 -0.1,
   0 18 23 0  0.1,
   0 18 23 1  0.3,
   0 42 36 0 -0.3,
   0 42 36 0 -0.1,
   0 42 36 0  0.1,
   0 42 36 1  0.3,
   0 87 26 0 -0.3,
   0 87 26 0 -0.1,
   0 87 26 0  0.1,
   0 87 26 1  0.3,
   0 50 26 0 -0.3,
   0 50 26 0 -0.1,
   0 50 26 0  0.1,
   0 50 26 1  0.3,
   0 18 28 0 -0.3,
   0 18 28 0 -0.1,
   0 18 28 0  0.1,
   0 18 28 1  0.3,
   0 111 31 0 -0.3,
   0 111 31 0 -0.1,
   0 111 31 0  0.1,
   0 111 31 1  0.3,
   0 18 32 0 -0.3,
   0 18 32 0 -0.1,
   0 18 32 0  0.1,
   0 18 32 1  0.3,
   0 20 21 0 -0.3,
   0 20 21 0 -0.1,
   0 20 21 0  0.1,
   0 20 21 1  0.3,
   0 12 29 0 -0.3,
   0 12 29 0 -0.1,
   0 12 29 0  0.1,
   0 12 29 1  0.3,
   0  9 21 0 -0.3,
   0  9 21 0 -0.1,
   0  9 21 0  0.1,
   0  9 21 1  0.3,
   0 17 32 0 -0.3,
   0 17 32 0 -0.1,
   0 17 32 0  0.1,
   0 17 32 1  0.3,
   0 28 25 0 -0.3,
   0 28 25 0 -0.1,
   0 28 25 0  0.1,
   0 28 25 1  0.3,
   0 55 30 0 -0.3,
   0 55 30 0 -0.1,
   0 55 30 0  0.1,
   0 55 30 1  0.3,
   0  9 40 0 -0.3,
   0  9 40 0 -0.1,
   0  9 40 0  0.1,
   0  9 40 1  0.3,
   0 10 19 0 -0.3,
   0 10 19 0 -0.1,
   0 10 19 0  0.1,
   0 10 19 1  0.3,
   0 47 22 0 -0.3,
   0 47 22 0 -0.1,
   0 47 22 0  0.1,
   0 47 22 1  0.3,
   1 76 18 0 -0.3,
   1 76 18 0 -0.1,
   1 76 18 0  0.1,
   1 76 18 1  0.3,
   1 38 32 0 -0.3,
   1 38 32 0 -0.1,
   1 38 32 0  0.1,
   1 38 32 1  0.3,
   1 19 20 0 -0.3,
   1 19 20 0 -0.1,
   1 19 20 0  0.1,
   1 19 20 1  0.3,
   1 10 30 0 -0.3,
   1 10 30 0 -0.1,
   1 10 30 0  0.1,
   1 10 30 1  0.3,
   1 19 18 0 -0.3,
   1 19 18 0 -0.1,
   1 19 18 0  0.1,
   1 19 18 1  0.3,
   1 24 24 0 -0.3,
   1 24 24 0 -0.1,
   1 24 24 0  0.1,
   1 24 24 1  0.3,
   1 31 30 0 -0.3,
   1 31 30 0 -0.1,
   1 31 30 0  0.1,
   1 31 30 1  0.3,
   1 14 35 0 -0.3,
   1 14 35 0 -0.1,
   1 14 35 0  0.1,
   1 14 35 1  0.3,
   1 11 27 0 -0.3,
   1 11 27 0 -0.1,
   1 11 27 0  0.1,
   1 11 27 1  0.3,
   1 67 20 0 -0.3,
   1 67 20 0 -0.1,
   1 67 20 0  0.1,
   1 67 20 1  0.3,
   1 41 22 0 -0.3,
   1 41 22 0 -0.1,
   1 41 22 0  0.1,
   1 41 22 1  0.3,
   1  7 28 0 -0.3,
   1  7 28 0 -0.1,
   1  7 28 0  0.1,
   1  7 28 1  0.3,
   1 22 23 0 -0.3,
   1 22 23 0 -0.1,
   1 22 23 0  0.1,
   1 22 23 1  0.3,
   1 13 40 0 -0.3,
   1 13 40 0 -0.1,
   1 13 40 0  0.1,
   1 13 40 1  0.3,
   1 46 33 0 -0.3,
   1 46 33 0 -0.1,
   1 46 33 0  0.1,
   1 46 33 1  0.3,
   1 36 21 0 -0.3,
   1 36 21 0 -0.1,
   1 36 21 0  0.1,
   1 36 21 1  0.3,
   1 38 35 0 -0.3,
   1 38 35 0 -0.1,
   1 38 35 0  0.1,
   1 38 35 1  0.3,
   1  7 25 0 -0.3,
   1  7 25 0 -0.1,
   1  7 25 0  0.1,
   1  7 25 1  0.3,
   1 36 26 0 -0.3,
   1 36 26 0 -0.1,
   1 36 26 0  0.1,
   1 36 26 1  0.3,
   1 11 25 0 -0.3,
   1 11 25 0 -0.1,
   1 11 25 0  0.1,
   1 11 25 1  0.3,
   1 151 22 0 -0.3,
   1 151 22 0 -0.1,
   1 151 22 0  0.1,
   1 151 22 1  0.3,
   1 22 32 0 -0.3,
   1 22 32 0 -0.1,
   1 22 32 0  0.1,
   1 22 32 1  0.3,
   1 41 25 0 -0.3,
   1 41 25 0 -0.1,
   1 41 25 0  0.1,
   1 41 25 1  0.3,
   1 32 35 0 -0.3,
   1 32 35 0 -0.1,
   1 32 35 0  0.1,
   1 32 35 1  0.3,
   1 56 21 0 -0.3,
   1 56 21 0 -0.1,
   1 56 21 0  0.1,
   1 56 21 1  0.3,
   1 24 41 0 -0.3,
   1 24 41 0 -0.1,
   1 24 41 0  0.1,
   1 24 41 1  0.3,
   1 16 32 0 -0.3,
   1 16 32 0 -0.1,
   1 16 32 0  0.1,
   1 16 32 1  0.3,
   1 22 26 0 -0.3,
   1 22 26 0 -0.1,
   1 22 26 0  0.1,
   1 22 26 1  0.3,
   1 25 21 0 -0.3,
   1 25 21 0 -0.1,
   1 25 21 0  0.1,
   1 25 21 1  0.3,
   1 13 36 0 -0.3,
   1 13 36 0 -0.1,
   1 13 36 0  0.1,
   1 13 36 1  0.3,
   1 12 37 0 -0.3,
   1 12 37 0 -0.1,
   1 12 37 0  0.1,
   1 12 37 1  0.3};

/* Rescale covariates: divide the baseline (column 2) by 4 and take the log of age (column 3). */
x[.,2]=x[.,2]/4;
x[.,3]=ln(x[.,3]);

/* Prepend an intercept column, sized from the data rather than a hard-coded 236 rows. */
x_t=ones(rows(x),1);

x=x_t~x;

ry=rows(x);

/*
Columns of X are: constant, treatment, base/4, log(age).
The two remaining data columns are not used by the model and are dropped.
*/
x=x[.,1:4];

/* Release the temporary intercept column to save space. */
x_t=0;

/*
Number of epilepsy seizures: four consecutive values per subject, one per
period, in the same subject order as the covariate rows above.
*/
y={5,3,3,3,
   3,5,3,3,
   2,4,0,5,
   4,4,1,4,
   7,18,9,21,
   5,2,8,7,
   6,4,0,2,
   40,20,23,12,
   5,6,6,5,
   14,13,6,0,
   26,12,6,22,
   12,6,8,4,
   4,4,6,2,
   7,9,12,14,
   16,24,10,9,
   11,0,0,5,
   0,0,3,3,
   37,29,28,29,
   3,5,2,5,
   3,0,6,7,
   3,4,3,4,
   3,4,3,4,
   2,3,3,5,
   8,12,2,8,
   18,24,76,25,
   2,1,2,1,
   3,1,4,2,
   13,15,13,12,
   11,14,9,8,
   8,7,9,4,
   0,4,3,0,
   3,6,1,3,
   2,6,7,4,
   4,3,1,3,
   22,17,19,16,
   5,4,7,4,
   2,4,0,4,
   3,7,7,7,
   4,18,2,5,
   2,1,1,0,
   0,2,4,0,
   5,4,0,3,
   11,14,25,15,
   10,5,3,8,
   19,7,6,7,
   1,1,2,3,
   6,10,8,8,
   2,1,0,0,
   102,65,72,63,
   4,3,2,4,
   8,6,5,7,
   1,3,1,5,
   18,11,28,13,
   6,3,4,0,
   3,5,4,3,
   1,23,19,8,
   2,3,0,1,
   0,0,0,0,
   1,4,3,2};

/*
The data above hold four rows per subject, one per period.  The retained
covariates (constant, treatment, base/4, log(age)) are the same on all of
a subject's rows, so keep each subject's first row.  Reshape the seizure
counts so that column t of Y holds the count for period t.
*/
if ry%4/=0;
  errorlog "Expected four data rows per subject.";
  end;
endif;
c=ry/4;
x1=x[seqa(1,4,c),.];
y=reshape(y,c,4);
ry=rows(y);

/* The full covariate matrix is no longer needed. */
x=0;

/*
Append the treatment x log(baseline/4) interaction as column 5:
x1 = constant, treatment, base/4, log(age), treatment*log(base/4).
*/
x1=x1~(x1[.,2].*ln(x1[.,3]));

/*
Time-dependent covariates: each earlier period's count plus 0.1, divided by
that period's mean count plus 0.1 (the 0.1 presumably keeps x^theta finite
when a count is zero).  Column 6 of x2, x3 and x4 is the previous period;
column 7 of x3 and x4 is two periods back.  x4 also carries period 1 in
column 8, which mgen does not read.
*/
yn=(y+0.1)./meanc(y+0.1)';
x2=x1~yn[.,1];
x3=x1~yn[.,2]~yn[.,1];
x4=x1~yn[.,3]~yn[.,2]~yn[.,1];
yn=0;

/*
theta[1]=coefficient for constant.
theta[2]=coefficient for treatment.
theta[3]=coefficient for log(baseline/4).
theta[4]=coefficient for log(age).
theta[5]=coefficient for treatment*log(baseline/4).
theta[6]=log of the b parameter for time 1 data (mgen uses exp(theta[6])).
theta[7]=log of the b parameter for time 2 data.
theta[8]=log of the b parameter for time 3 data.
theta[9]=log of the b parameter for time 4 data.
theta[10]=c parameter for time 2 data.
theta[11]=c parameter for time 4 data.
theta[12]=common coefficient for 4/3, 3/2 and 2/1 ((y[.,t-1]+0.1)/meanc(y[.,t-1]+0.1)).
theta[13]=common coefficient for 4/2, 3/1 ((y[.,t-2]+0.1)/meanc(y[.,t-2]+0.1)).
*/

theta={-0.55,
    -0.76,
     0.71,
     0.39,
     0.29,
     0.99,
     0.60,
     0.71,
    -3.08,
     0.74,
     0.16,
     0.16,
     0.10};

npar=rows(theta);

/* Global written by mgen on every call. */
c_mu=0;

/* Minimise the negative log-likelihood with BFGS. */
_opmdmth="BFGS";
{theta,minrss,grad,rc}=optprt(optmum(&log_likelihood,theta));

/* Output to file output.asc (screen echo is disabled while writing). */
output file = output.asc reset;
screen off;
format /rd 12,5;
print;
print "Observed Data";
print y;
print;
/* log_likelihood returns the NEGATIVE log-likelihood, which optmum minimises. */
print "Negative log-likelihood";
print minrss;
print;
print "Estimated Theta";
print theta;
print;
print "Number of parameters";
print npar;
/*
Numerical Hessian of the negative log-likelihood at the estimate; its
inverse is used as the variance-covariance matrix of the estimates.
*/
jj=hessp(&log_likelihood,theta);
jj1=inv(jj);
print;
print "Variance-covariance Matrix - Delta (z)";
print jj1;
print;
print "Standard errors of the parameters";
jj2=sqrt(diag(jj1));
print jj2;

/* Stop writing to output.asc and restore screen output. */
output off;
screen on;
end;

/*
Function to generate the matrix that is to be exponentiated.

Returns the (y+1) x (y+1) generator of a pure-birth counting process
truncated at the observed count y.  Row i is the state "count = i-1": its
diagonal entry is -rate and, except in the last row, its superdiagonal
entry is +rate.  log_likelihood uses exp(generator)[1,y+1] as the
probability of going from 0 to y events in the period.

  theta : parameter vector (see the theta[] list in the main script)
  x     : one subject's covariate row (a row of x1, x2, x3 or x4)
  y     : observed count for the period
  b     : index of the constant's coefficient (log_likelihood passes 1)
  e     : index of the log(age) coefficient; theta[e+1+t] is the log of the
          b parameter for period t (log_likelihood passes 4)
  t     : period number, 1 to 4
Side effect: stores the period mean mu in the global c_mu.
*/
proc mgen(theta,x,y,b,e,t);
  local mw,i,mu,temp,bt,cp,ns;
  ns=y+1;
  /* Mean for this period. */
  if t==1;
    mu=(x[3]^theta[b+2])*exp(x[1:2]*theta[b:b+1]+x[4]*theta[e]+x[5]*theta[e+1]);
  else;
    /* Periods 2-4 also use the previous period's normalised count, x[6]. */
    mu=(x[3]^theta[b+2])*(x[6]^theta[12])*exp(x[1:2]*theta[b:b+1]+x[4]*theta[e]+x[5]*theta[e+1]);
    if t>=3;
      /* Periods 3-4 also use the count two periods back, x[7]. */
      mu=mu*(x[7]^theta[13]);
    endif;
  endif;
  c_mu=mu;

  /* b parameter for this period (theta holds it on the log scale). */
  bt=exp(theta[e+1+t]);
  /* Periods 2 and 4 use rates proportional to (b+n)^c; periods 1 and 3 are linear (c=1). */
  if t==2;
    cp=theta[t+e+4];
  elseif t==4;
    cp=theta[t+e+3];
  else;
    cp=1;
  endif;
  /*
  Rate multiplier.  The general formula is 0/0 at c=1; its limit there is
  ln(1+mu/b), so use that form when c is within 1e-8 of 1 (always for
  periods 1 and 3) instead of returning NaN.
  */
  if abs(1-cp)<1e-8;
    temp=ln(mu+bt)-ln(bt);
  else;
    temp=((((mu/bt)+1)^(1-cp))-1)*((bt^(1-cp))/(1-cp));
  endif;

  mw=zeros(ns,ns);
  i=0;
  do while (i<ns);
    i=i+1;
    /* The first row represents a count of 0, so row i is count i-1. */
    if t==1 or t==3;
      mw[i,i]=-temp*(bt+i-1);
    else;
      mw[i,i]=-temp*(bt+i-1)^cp;
    endif;
    if i<ns;
      mw[i,i+1]=-mw[i,i];
    endif;
  endo;
  retp(mw);
endp;

/*
Function to calculate log-likelihood.

Returns the NEGATIVE log-likelihood (so optmum can minimise it), summed over
the four periods.  Reads the globals y and x1..x4 prepared by the main
script.
*/
proc log_likelihood(theta);
  local ll;
  ll=ll_period(theta,x1,y[.,1],1);
  ll=ll+ll_period(theta,x2,y[.,2],2);
  ll=ll+ll_period(theta,x3,y[.,3],3);
  ll=ll+ll_period(theta,x4,y[.,4],4);
  retp(-ll);
endp;

/*
Log-likelihood contribution of period t.  For each subject j, exponentiate
the generator built by mgen and take the probability in row 1 (count 0),
column yt[j]+1 (the observed count).  Probabilities below 1e-300 are
floored so that ln() stays finite.
  xt : covariate matrix for period t (one row per subject)
  yt : observed counts for period t (one entry per subject)
*/
proc ll_period(theta,xt,yt,t);
  local s,j,q,p;
  s=0;
  j=0;
  do while (j<rows(yt));
    j=j+1;
    q=mexp(mgen(theta,xt[j,.],yt[j],1,4,t));
    p=q[1,yt[j]+1];
    if p<1e-300;
      p=1e-300;
    endif;
    s=s+ln(p);
  endo;
  retp(s);
endp;
