R/validation.R
cross_validation.Rd
cross_validation(
  ngme,
  type = "k-fold",
  seed = NULL,
  print = FALSE,
  N = 5,
  n_gibbs_samples = 100,
  n_burnin = 100,
  k = 5,
  percent = 0.2,
  times = 10,
  transform = identity,
  test_idx = NULL,
  train_idx = NULL,
  keep_pred = FALSE,
  parallel = TRUE,
  thining_gap = 1,
  cores_layer1 = if (parallel) parallel::detectCores() else 1,
  cores_layer2 = if (parallel) parallel::detectCores() else 1
)
ngme: an ngme object, or a list of ngme objects (when comparing multiple models)
type: character, one of c("k-fold", "loo", "lpo", "custom"):
k-fold is k-fold cross-validation (provide k);
loo is leave-one-out;
lpo is leave-percent-out (provide percent,
a proportion from 0 to 1);
custom is a user-defined split (provide test_idx
and train_idx)
seed: random seed
print: logical, print information during computation
N: integer, number of simulations (e.g., MAE, MSE, ... are estimated N times)
n_gibbs_samples: number of Gibbs samples
n_burnin: number of burn-in samples
k: integer, number of folds (only for the k-fold type)
percent: proportion of the data used for testing, from 0 to 1 (only for the lpo type)
times: number of test cases (only for the lpo type)
transform: a function to transform the data (e.g., log, exp); for example, the MAE is computed as |transform(Y) - transform(Y_pred)|
test_idx: a list of indices of the data points to be predicted (only for the custom type)
train_idx: a list of indices of the data points used for re-sampling, not re-estimation (only for the custom type)
keep_pred: logical, keep the test information (pred_1, pred_2) in the return value as attributes; pred_1 and pred_2 are the predictions from the two chains
parallel: logical, run in parallel mode
thining_gap: integer, the gap between kept samples for thinning; if 0, no thinning is applied; if 1, 50% of the samples are kept
cores_layer1: integer, number of cores for the first layer (over the test samples)
cores_layer2: integer, number of cores for the second layer (over computing the scores for the N simulations)
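The following is a minimal sketch of burn-in removal followed by thinning. It assumes the gap counts the draws skipped between kept ones, so a gap of 1 keeps every other draw. The helper `thin_samples()` is hypothetical and is not part of ngme2.

```r
# Hypothetical helper (not ngme2 code): drop the first n_burnin draws, then
# keep every (gap + 1)-th remaining draw. gap = 0 keeps all draws and
# gap = 1 keeps every other draw (assumed reading of thining_gap).
thin_samples <- function(draws, n_burnin = 0, gap = 1) {
  stopifnot(n_burnin >= 0, gap >= 0, n_burnin < length(draws))
  kept <- draws[seq.int(n_burnin + 1, length(draws))]  # remove burn-in draws
  kept[seq.int(1, length(kept), by = gap + 1)]         # thin the remainder
}
```

For example, `thin_samples(1:10, n_burnin = 4, gap = 1)` keeps draws 5, 7 and 9.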
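The following is a minimal sketch of the two-layer parallelism described by cores_layer1 and cores_layer2. It uses parallel::mclapply for both layers. The helper `two_layer_apply()` and its `score_fun(test, sim)` callback are hypothetical, and ngme2 may organise its parallel loops differently. Because mclapply relies on forking, both core counts must be 1 on Windows.

```r
# Hypothetical sketch (not ngme2 code) of nested parallel loops.
# Outer layer: one task per test set (cf. cores_layer1).
# Inner layer: one task per simulation 1..N (cf. cores_layer2).
two_layer_apply <- function(test_sets, N, score_fun,
                            cores_layer1 = 1L, cores_layer2 = 1L) {
  parallel::mclapply(test_sets, function(test) {
    # scores for this test set across the N simulations, as a vector
    unlist(parallel::mclapply(seq_len(N), function(sim) score_fun(test, sim),
                              mc.cores = cores_layer2))
  }, mc.cores = cores_layer1)
}
```

With `mc.cores = 1` both layers fall back to sequential `lapply`, which keeps the sketch portable.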
Compute the cross-validation for the ngme model.
Performs cross-validation for an ngme model. The data are first split into sub_groups (a list of test targets and training data).
Returns:
1. the mean of the N estimates of four criteria: MSE, MAE, CRPS, and sCRPS;
2. the standard deviation of the N estimates of the same four criteria.
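The following is a minimal sketch of how data indices could be split into test and training sub-groups for the k-fold, loo and lpo types. The custom type takes test_idx and train_idx directly. The helper `make_splits()` is hypothetical. It only illustrates the resampling schemes and is not ngme2's internal code.

```r
# Hypothetical helper (not ngme2 code): build a list of test/train index
# pairs for n data points. Setting `seed` changes the global RNG state.
make_splits <- function(n, type = c("k-fold", "loo", "lpo"),
                        k = 5, percent = 0.2, times = 10, seed = NULL) {
  type <- match.arg(type)
  if (!is.null(seed)) set.seed(seed)
  idx <- seq_len(n)
  test_sets <- switch(type,
    # k-fold: randomly assign each index to one of k folds
    "k-fold" = split(sample(idx), rep_len(seq_len(k), n)),
    # leave-one-out: every index is held out exactly once
    "loo" = as.list(idx),
    # leave-percent-out: `times` random subsets of size round(percent * n)
    "lpo" = replicate(times,
                      sort(sample(idx, size = max(1, round(percent * n)))),
                      simplify = FALSE)
  )
  # pair each test set with its complementary training indices
  lapply(test_sets, function(test) {
    list(test = sort(test), train = setdiff(idx, test))
  })
}
```

For example, `make_splits(10, "k-fold", k = 5, seed = 1)` returns five disjoint test sets of two indices each. Together they cover 1 to 10.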
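The following is a sketch of the four criteria for a single test point. MSE and MAE follow the rule |transform(Y) - transform(Y_pred)|. CRPS is estimated from predictive draws as E|X - y| - 0.5 E|X - X'|, and sCRPS as E|X - y| / E|X - X'| + 0.5 log E|X - X'|. Several details are assumptions rather than ngme2's documented implementation:
- the helper names;
- the lower-is-better sign convention;
- the use of paired draws from two independent chains (cf. pred_1 and pred_2) to estimate E|X - X'|.

```r
# Hypothetical helpers (not ngme2 code) for the four criteria.

# MSE and MAE for point predictions, using |transform(Y) - transform(Y_pred)|.
point_scores <- function(y, y_pred, transform = identity) {
  err <- transform(y) - transform(y_pred)
  c(MSE = mean(err^2), MAE = mean(abs(err)))
}

# Sample-based CRPS and sCRPS for one observation y, given equally many draws
# from two independent chains; both are oriented so that lower is better.
sample_crps <- function(y, draws_1, draws_2) {
  stopifnot(length(draws_1) == length(draws_2))
  e_xy <- mean(abs(draws_1 - y))        # estimates E|X - y|
  e_xx <- mean(abs(draws_1 - draws_2))  # estimates E|X - X'|
  c(CRPS = e_xy - 0.5 * e_xx,
    sCRPS = e_xy / e_xx + 0.5 * log(e_xx))
}
```

Note that sCRPS is undefined when the two chains produce identical draws, because E|X - X'| is then estimated as 0.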
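The following is a minimal sketch of the aggregation behind the return value. Given an N-by-4 matrix of criteria (one row per simulation, one column per criterion), it takes the column means and standard deviations. The helper `summarise_scores()` is hypothetical.

```r
# Hypothetical helper (not ngme2 code): summarise N repeated estimates of
# MSE, MAE, CRPS and sCRPS by their mean and standard deviation.
summarise_scores <- function(score_mat) {
  list(mean = colMeans(score_mat),
       sd = apply(score_mat, 2, sd))
}
```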