require(Matrix)




# Multinomial log-likelihood sum_i sum_k Y[i, k] * log(P[i, k]), where P is
# the row-wise softmax of the linear predictor eta = 1 b' + X B.
#
# B: ncol(X) x ncol(Y) coefficient matrix; b: intercepts (length ncol(Y));
# X: design matrix; Y: class-indicator matrix (one column per class).
logL = function(B, b, X, Y) {
  eta = as.matrix(matrix(1, nrow(X), 1) %*% t(b) + X %*% B)
  # log-sum-exp with each row's maximum factored out, so large linear
  # predictors do not overflow exp() to Inf (which would produce NaN).
  rowMax = apply(eta, 1, max)
  logNorm = rowMax + log(rowSums(exp(eta - rowMax)))
  sum(Y*(eta - logNorm))
}


# Nuclear (trace) norm of B: the sum of its singular values.
nuclear = function(B) {
  singularValues = svd(B)$d
  sum(abs(singularValues))
}




# Nuclear-norm penalized multinomial objective: the negative log-likelihood
# logL(B, b, X, Y) plus lambda times the nuclear norm of B.
objective = function(B, b, X, Y, lambda) {
  negLogLik = -logL(B, b, X, Y)
  negLogLik + lambda*nuclear(B)
}


# Penalized objective from precomputed fitted probabilities: minus the sum of
# log(P[W]) plus lambda times the nuclear norm of B. W holds the (linear)
# indices of the entries of P to score, e.g. which(Y == 1) for an indicator
# matrix Y.
objectiveFast = function(B, P, W, lambda) {
  penalty = lambda*sum(abs(svd(B)$d))
  observedLogProb = log(P[W])
  penalty - sum(observedLogProb)
}


# Proximal operator of threshold * sum_g ||B_g||_*, where B_g is the block of
# rows of B whose label in `group` is g: the singular values of each row
# block are soft-thresholded by `threshold` and the block is rebuilt.
#
# group: one label per row of B; labels need not be contiguous integers.
prox = function(B, threshold, group) {
  B. = B
  for (g in unique(group)) {
    rows = group == g
    # drop = FALSE keeps a single-row block as a 1 x ncol(B) matrix.
    SVD = svd(B[rows, , drop = FALSE])
    D = pmax(SVD$d - threshold, 0)
    B.[rows, ] = SVD$u %*% (D * t(SVD$v))
  }
  B.
}




# Proximal gradient descent for nuclear-norm penalized multinomial regression
# at a single penalty value lambda.
#
# Arguments:
#   B, b         starting coefficient matrix (ncol(X) x ncol(Y)) and intercepts
#   X, Y         design matrix and 0/1 class-indicator matrix
#   lambda       penalty parameter
#   s            initial step size; halved whenever a step would increase the
#                objective (backtracking), and kept at the reduced value
#   group        row-group labels of B; the penalty is lambda times the sum of
#                the nuclear norms of the row blocks (NULL = one block)
#   accelerated  extrapolate each step with momentum weight it/(it + 3)
#   eps          stop once the relative objective decrease is at most eps
#   maxit        maximum number of iterations
#   quiet        if FALSE, print the objective and every step-size reduction
#
# Returns list(B, b, objectivePath, time): the final iterate, the objective
# at the start and after each accepted step, and the elapsed time.
# Stops with an error if backtracking drives the step size to zero.
PGDnpmr = function(B, b, X, Y, lambda, s, group = NULL, accelerated = TRUE,
  eps = 1e-7, maxit = 1e5, quiet = TRUE) {

  if (is.null(group)) group = rep(1, nrow(B))
  sumY = colSums(Y)
  XtY = crossprod(X, Y)
  W = which(Y == 1)  # entries of P scored by the likelihood

  # Row-wise softmax; subtracting each row's maximum avoids exp() overflow
  # without changing the probabilities.
  softmax = function(eta) {
    eta = as.matrix(eta)
    expeta = exp(eta - apply(eta, 1, max))
    expeta/rowSums(expeta)
  }

  # Penalized negative log-likelihood. The penalty sums nuclear norms over the
  # row groups so the line search uses the same penalty whose proximal
  # operator prox() applies; with a single group this equals
  # -sum(log(P[W])) + lambda * (nuclear norm of B).
  penalizedObjective = function(B, P) {
    penalty = sum(vapply(unique(group), function(g) {
      sum(svd(as.matrix(B[group == g, , drop = FALSE]))$d)
    }, numeric(1)))
    -sum(log(P[W])) + lambda*penalty
  }

  # One proximal-gradient step of size `step` from the current (B, b, P).
  # With acceleration the result is extrapolated away from the last
  # *accepted* proximal point (prevC, prevc), which is not modified here, so
  # rejected trials in the line search never become the momentum reference.
  takeStep = function(step) {
    proxB = prox(B + step*XtY - step*crossprod(X, P), step*lambda, group)
    proxb = b + step*sumY - step*colSums(P)
    newB = proxB
    newb = proxb
    if (accelerated) {
      newB = proxB + it/(it + 3)*(proxB - prevC)
      newb = proxb + it/(it + 3)*(proxb - prevc)
    }
    newP = softmax(t(t(as.matrix(X %*% newB)) + newb))
    list(B = newB, b = newb, C = proxB, c = proxb, P = newP,
      objective = penalizedObjective(newB, newP))
  }

  P = softmax(matrix(1, nrow(X), 1) %*% t(b) + X %*% B)
  if (accelerated) {
    prevC = B
    prevc = b
  }

  objectivePath = c(penalizedObjective(B, P), rep(NA, maxit))
  it = 0
  relDecrease = eps + 1
  startTime = Sys.time()
  while (abs(relDecrease) > eps && it < maxit) {
    it = it + 1
    if (!quiet) print(objectivePath[it])
    candidate = takeStep(s)
    # Backtracking line search: halve s until the objective does not
    # increase. A NaN objective counts as an increase.
    while (!isTRUE(candidate$objective <= objectivePath[it])) {
      if (!quiet) print(paste('s =', s))
      s = s/2
      if (s == 0) stop('step size underflowed to zero during line search')
      candidate = takeStep(s)
    }
    objectivePath[it + 1] = candidate$objective
    if (accelerated) {
      prevC = candidate$C
      prevc = candidate$c
    }
    # Subtracting one constant from every entry of B (or of b) shifts each
    # row of eta by a class-independent amount, so P is unchanged.
    # NOTE(review): it does change the nuclear norm, so objectivePath[it + 1]
    # is the objective of the uncentered iterate — confirm this is intended.
    B = candidate$B - mean(candidate$B)
    b = candidate$b - mean(candidate$b)
    P = candidate$P
    relDecrease = (objectivePath[it] - objectivePath[it + 1])/
      objectivePath[it + 1]
  }

  list(B = B, b = b, objectivePath = objectivePath[seq_len(it + 1)],
    time = Sys.time() - startTime)
}




# Fit nuclear-norm penalized multinomial regression along a path of penalty
# values, warm-starting each fit from the previous solution.
#
# X: design matrix. Y: vector of class labels, or an indicator matrix with
# one column per class. lambda: penalty values, fitted in the order given.
# s, eps, group, accelerated, quiet: passed to PGDnpmr(). B.init, b.init:
# optional starting coefficients and intercepts.
#
# Returns an object of class 'npmr' holding B (ncol(X) x ncol(Y) x
# length(lambda)), b (ncol(Y) x length(lambda), centered), the minimum
# objective reached for each lambda, and lambda. Prints progress to stdout.
npmr = function(X, Y, lambda, s = 0.1, eps = 1e-6, group = NULL,
  accelerated = TRUE, B.init = NULL, b.init = NULL, quiet = TRUE) {

  # Expand a label vector into an indicator matrix, one column per class.
  if (is.null(dim(Y))) {
    labels = sort(unique(Y))
    Y = model.matrix(~ as.factor(Y) - 1)
    colnames(Y) = labels
  }

  if (is.null(B.init)) {
    # X %*% (all-ones) is constant across classes within each row, so this
    # start gives the same initial probabilities as B = 0.
    # NOTE(review): a random rnorm() start was previously commented out here.
    B = matrix(1, ncol(X), ncol(Y))
  } else B = B.init

  if (is.null(b.init)) {
    # Centered log class frequencies.
    # NOTE(review): a class with no observations gives log(0) and then NaN.
    b = log(colMeans(Y)) - mean(log(colMeans(Y)))
  } else b = b.init

  nlambda = length(lambda)
  B.path = array(NA, dim = c(ncol(X), ncol(Y), nlambda))
  b.path = array(NA, dim = c(ncol(Y), nlambda))
  objective.path = rep(NA, nlambda)

  cat('Progress: ')
  # seq_len: an empty lambda yields an empty fit rather than indexing 1:0.
  for (l in seq_len(nlambda)) {
    solution = PGDnpmr(B, b, X, Y, lambda[l], s = s, group = group,
      accelerated = accelerated, eps = eps, quiet = quiet)
    B = solution$B
    b = solution$b
    B.path[, , l] = as.matrix(B)
    b.path[, l] = scale(b, scale = FALSE)
    objective.path[l] = min(solution$objectivePath)
    cat(round(100*l/nlambda))
    cat('% ')
  }
  cat('\n')

  fit = list(B = B.path, b = b.path, objective = objective.path,
    lambda = lambda)
  class(fit) = 'npmr'
  return(fit)
}




# Print a table of lambda, the rank of the coefficient matrix at each lambda,
# and the objective value. Returns the printed data frame invisibly.
print.npmr = function(x, ...) {
  p = dim(x$B)[1]
  K = dim(x$B)[2]
  rank = rep(NA, length(x$lambda))

  for (l in seq_along(x$lambda)) {
    # matrix() keeps B_l two-dimensional even when p or K is 1;
    # rankMatrix() requires a matrix.
    B = matrix(x$B[, , l], p, K)
    # The any-nonzero factor forces rank 0 when every entry of B_l is zero.
    rank[l] = rankMatrix(B)*(max(abs(B) > 0))
  }

  print(data.frame(lambda = x$lambda, rank = rank, objective = x$objective))
}




# Predicted class probabilities for new data at every lambda of an npmr fit.
#
# Returns an array of dimension nrow(newx) x (number of classes) x
# (number of lambdas); slice [, , l] is the row-wise softmax of
# 1 b_l' + newx B_l.
predict.npmr = function(object, newx, ...) {
  p = dim(object$B)[1]
  K = dim(object$B)[2]
  nlambda = ncol(object$b)
  P = array(NA, c(nrow(newx), K, nlambda))
  for (l in seq_len(nlambda)) {
    # matrix() keeps B_l as p x K even when p or K is 1.
    B = matrix(object$B[, , l], p, K)
    eta = as.matrix(matrix(1, nrow(newx), 1) %*% t(object$b[, l]) +
      newx %*% B)
    # Subtract each row's maximum before exponentiating so large linear
    # predictors do not overflow to Inf/Inf = NaN; probabilities unchanged.
    expeta = exp(eta - apply(eta, 1, max))
    P[, , l] = expeta/rowSums(expeta)
  }
  P
}


# Locate the fitted lambda closest (in squared distance) to `lambda` and
# return, invisibly, the SVD of the coefficient matrix there.
# NOTE(review): despite its name this method draws nothing; it looks like an
# unfinished stub.
plot.npmr = function(x, lambda, ...) {
  closest = which.min((x$lambda - lambda)^2)
  B.closest = x$B[, , closest]
  invisible(svd(B.closest))
}


# Choose lambda for npmr() by cross-validation on multinomial deviance, then
# refit on all data at the selected value.
#
# foldid: optional fold label for each row of X; otherwise rows are split
# at random into nfolds folds. The remaining arguments are passed to npmr().
#
# Returns an object of class 'cv.npmr' with the held-out deviance of each
# observation at each lambda (deviance), its column means (cvm) and standard
# errors (cvsd), the refit at lambda.min (fit), and lambda.min. Warns when
# lambda.min is an endpoint of lambda.
cv.npmr = function(X, Y, lambda, s = 0.1, eps = 1e-4, group = NULL,
  accelerated = TRUE, B.init = NULL, b.init = NULL, foldid = NULL,
  nfolds = 10) {

  # Expand a label vector into an indicator matrix, one column per class.
  if (is.null(dim(Y))) {
    labels = sort(unique(Y))
    Y = model.matrix(~ as.factor(Y) - 1)
    colnames(Y) = labels
  }

  if (is.null(foldid)) {
    foldid = sample(rep(seq_len(nfolds), length.out = nrow(X)))
  }

  # Out-of-fold predicted probabilities: observation x class x lambda.
  pred = array(NA, c(nrow(X), ncol(Y), length(lambda)))

  for (fold in unique(foldid)) {
    train = foldid != fold
    # drop = FALSE keeps single-row subsets as matrices.
    foldFit = npmr(X[train, , drop = FALSE], Y[train, , drop = FALSE],
      lambda, s = s, eps = eps, group = group, accelerated = accelerated,
      B.init = B.init, b.init = b.init)
    pred[!train, , ] = predict(foldFit, X[!train, , drop = FALSE])
  }

  # Per-observation deviance -2 log(sum_k Y[i, k] * pred[i, k, l]), i.e. the
  # probability given to the observed class when Y is an indicator matrix;
  # one column per lambda.
  deviance = -2 * log(apply(pred, 3, function(pred) {rowSums(pred * Y)}))

  ell.min = which.min(colSums(deviance))
  lambda.min = lambda[ell.min]
  if (lambda.min == min(lambda)) {
    warning('lambda chosen through CV is smallest lambda')
  } else if (lambda.min == max(lambda)) {
    warning('lambda chosen through CV is largest lambda')
  }

  # Warm-start the full-data fit from the last fold's solution at lambda.min;
  # matrix() restores the p x K shape that [, , ell.min] drops when p or K
  # is 1.
  B.start = matrix(foldFit$B[, , ell.min], dim(foldFit$B)[1],
    dim(foldFit$B)[2])

  fit = list(deviance = deviance, cvm = colMeans(deviance),
    cvsd = apply(deviance, 2, sd)/sqrt(nrow(deviance)),
    fit = npmr(X, Y, lambda.min, s = s, eps = eps, group = group,
      accelerated = accelerated, B.init = B.start,
      b.init = foldFit$b[, ell.min]), lambda.min = lambda.min)
  class(fit) = "cv.npmr"
  return(fit)
}




# Class probabilities for newx from the refit stored in a cv.npmr object
# (fitted at the single value lambda.min): the first lambda slice of the
# npmr prediction, which drops to a vector when newx has one row.
# `...` keeps the signature consistent with the predict() generic and is
# forwarded to it.
predict.cv.npmr = function(object, newx, ...) {
  predict(object$fit, newx, ...)[, , 1]
}

