
These are the configuration options for the ngme optimization process.

Usage

control_opt(
  seed = Sys.time(),
  burnin = 100,
  iterations = 500,
  estimation = TRUE,
  standardize_fixed = TRUE,
  n_batch = 10,
  iters_per_check = iterations/n_batch,
  optimizer = adam(),
  start_sd = 0.5,
  n_parallel_chain = 4,
  max_num_threads = n_parallel_chain,
  print_check_info = FALSE,
  max_relative_step = 0.5,
  max_absolute_step = 0.5,
  rao_blackwellization = FALSE,
  n_trace_iter = 10,
  sampling_strategy = "all",
  solver_backend = if (Sys.info()["sysname"] == "Darwin") "accelerate" else "cholmod",
  solver_type = "llt",
  verbose = FALSE,
  store_traj = TRUE,
  robust = FALSE,
  n_min_batch = 3,
  n_slope_check = 3,
  trend_std_conv_check = TRUE,
  std_lim = 0.01,
  trend_lim = 0.01,
  R_hat_conv_check = TRUE,
  max_R_hat = 1.1,
  pflug_conv_check = TRUE,
  pflug_alpha = 0.9
)
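A minimal usage sketch, assuming the ngme2 package (which provides control_opt) is installed. The argument values are illustrative only; every argument not listed keeps the default shown above, and the call returns a list of control variables.

```r
library(ngme2)

# Illustrative settings: 1000 iterations split into 20 batches,
# so convergence checks run every 1000 / 20 = 50 iterations.
ctrl <- control_opt(
  seed = 42,
  burnin = 100,
  iterations = 1000,
  n_batch = 20,
  n_parallel_chain = 4,
  max_num_threads = 8,    # 4 threads for the chains, the remaining 4 for replicates
  print_check_info = TRUE
)
```

Because iters_per_check defaults to iterations/n_batch and R evaluates default arguments lazily, it picks up the values passed here (50 in this example).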

Arguments

seed

set the seed for the pseudo-random number generator

burnin

number of burn-in iterations run before optimization starts

iterations

number of optimization iterations

estimation

whether to run the estimation process (calls the C++ backend)

standardize_fixed

whether to standardize the fixed effects

n_batch

number of checkpoints; optimization is split into n_batch equal batches

iters_per_check

number of iterations between consecutive checkpoints (alternatively, specify n_batch)
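The relation between iterations, n_batch and iters_per_check is plain arithmetic. The base-R sketch below lists the checkpoint positions for the defaults, assuming checkpoints fall at the end of each batch and are counted from the start of optimization (excluding burn-in); that counting convention is an assumption.

```r
iterations <- 500
n_batch <- 10
iters_per_check <- iterations / n_batch   # 50 iterations per batch

# Assumed checkpoint positions: the last iteration of each batch
checkpoints <- seq(iters_per_check, iterations, by = iters_per_check)
checkpoints
#> [1]  50 100 150 200 250 300 350 400 450 500
```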

optimizer

the SGD optimization method, given as the output of one of the optimizer functions sgd(), precond_sgd(), momentum(), adagrad(), rmsprop(), adam() or adamW() (default adam()); see ?sgd, ?precond_sgd, ?momentum, ?adagrad, ?rmsprop, ?adam, ?adamW

start_sd

standard deviation used to randomize the initial parameters (the first chain starts from the fixed initial values, the other chains from randomized values); set to 0 to start all chains from the fixed initial values
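A minimal sketch of how start_sd might be applied, assuming the other chains add independent Gaussian noise with standard deviation start_sd to the initial parameter vector on its internal scale. The helper init_chains is hypothetical and not part of ngme2.

```r
# Hypothetical helper: one row of starting values per chain.
init_chains <- function(theta0, n_parallel_chain = 4, start_sd = 0.5) {
  p <- length(theta0)
  # Every chain starts from the fixed initial values ...
  inits <- matrix(theta0, nrow = n_parallel_chain, ncol = p, byrow = TRUE)
  # ... and chains 2..n get Gaussian noise (assumed perturbation scheme)
  if (n_parallel_chain > 1 && start_sd > 0) {
    noise <- matrix(rnorm((n_parallel_chain - 1) * p, sd = start_sd),
                    nrow = n_parallel_chain - 1)
    inits[-1, ] <- inits[-1, , drop = FALSE] + noise
  }
  inits
}

set.seed(1)
init_chains(c(0, 1, 2))
```

With start_sd = 0 the noise step is skipped, so every row equals theta0.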

n_parallel_chain

number of parallel chains

max_num_threads

maximum number of threads used for parallel computing; defaults to n_parallel_chain. If it exceeds n_parallel_chain, the additional threads are used to parallelize over the replicates of the model.

print_check_info

whether to print the convergence information

max_relative_step

maximum relative step allowed in one iteration

max_absolute_step

maximum absolute step allowed in one iteration
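A sketch of step limiting, assuming the relative limit is taken with respect to the current parameter magnitude and that both limits apply at once (the smaller one wins). Both assumptions, and the helper clamp_step, are hypothetical; the actual rule in ngme2 may differ.

```r
# Hypothetical helper: clip each component of a proposed update.
clamp_step <- function(theta, step,
                       max_relative_step = 0.5, max_absolute_step = 0.5) {
  # Per-parameter limit: the tighter of the absolute and relative bounds
  limit <- pmin(max_absolute_step, max_relative_step * abs(theta))
  pmax(pmin(step, limit), -limit)
}

clamp_step(theta = c(1, 10, 0.1), step = c(2, -2, 0.01))
#> [1]  0.50 -0.50  0.01
```

Under this reading a parameter sitting exactly at 0 could not move; if the real rule instead allows whichever bound is larger, replace pmin with pmax in the first line of the function.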

rao_blackwellization

whether to use Rao-Blackwellization

n_trace_iter

number of iterations used to approximate the trace (Hutchinson's trick)
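Hutchinson's trick estimates tr(A) as the average of z'Az over random probe vectors z with independent ±1 (Rademacher) entries, since E[zz'] = I. The base-R sketch below uses an explicit matrix A and one probe per iteration; the helper hutchinson_trace is hypothetical, and which matrix ngme2 applies the estimator to is not stated here.

```r
# Hypothetical helper: Hutchinson trace estimator with Rademacher probes.
hutchinson_trace <- function(A, n_trace_iter = 10) {
  n <- nrow(A)
  est <- numeric(n_trace_iter)
  for (k in seq_len(n_trace_iter)) {
    z <- sample(c(-1, 1), n, replace = TRUE)   # Rademacher probe vector
    est[k] <- drop(crossprod(z, A %*% z))      # z' A z
  }
  mean(est)                                    # unbiased estimate of tr(A)
}

set.seed(1)
hutchinson_trace(diag(1:4))
```

For a diagonal A every z'Az equals the exact trace (here 10), which makes the example easy to check; for a general A the estimate is unbiased and its variance shrinks as n_trace_iter grows.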

sampling_strategy

subsampling method for the replicates of the model, one of c("all", "ws"): "all" uses all replicates in each iteration; "ws" means weighted sampling (each iteration uses one replicate to compute the gradient, sampled with probability proportional to its number of observations)
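The two strategies can be sketched with base R's sample(), which accepts a prob argument for weighted draws. The helper pick_replicates and the vector n_obs (observations per replicate) are hypothetical illustrations, not ngme2 code.

```r
# Hypothetical helper: which replicates contribute to one gradient step.
pick_replicates <- function(n_obs, sampling_strategy = c("all", "ws")) {
  sampling_strategy <- match.arg(sampling_strategy)
  if (sampling_strategy == "all") {
    return(seq_along(n_obs))                   # every replicate, every iteration
  }
  # "ws": one replicate, probability proportional to its number of observations
  sample(seq_along(n_obs), size = 1, prob = n_obs / sum(n_obs))
}

n_obs <- c(100, 300, 600)
pick_replicates(n_obs, "all")
set.seed(1)
pick_replicates(n_obs, "ws")
```

With these counts the third replicate is drawn with probability 0.6 under "ws".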

solver_backend

linear solver backend, one of "eigen", "cholmod", "accelerate", "pardiso"; defaults to "accelerate" on macOS (Darwin) and "cholmod" otherwise

solver_type

factorization type: "llt" or "ldlt"

verbose

whether to print information during estimation

store_traj

store the optimizer trajectory for diagnostics (set FALSE to reduce memory)

n_min_batch

minimum number of checkpoints before any convergence diagnostic is attempted

n_slope_check

number of checkpoints used as the regression window for the trend test

trend_std_conv_check

enable the trend/std diagnostic (uses std_lim, trend_lim, n_slope_check)

std_lim

maximum allowed relative standard deviation (sqrt(var)/|mean|) for the trend/std diagnostic

trend_lim

maximum allowed absolute slope of the linear trend of the means, for the trend/std diagnostic
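A sketch of the trend/std test described in Details, assuming means holds one parameter's mean at each checkpoint so far, chain_var is its across-chain variance at the latest checkpoint, and the slope is the least-squares slope per checkpoint over the last n_slope_check points. These inputs and the helper trend_std_check are hypothetical.

```r
# Hypothetical helper: trend/std check for a single parameter.
trend_std_check <- function(means, chain_var, n_slope_check = 3, n_min_batch = 3,
                            std_lim = 0.01, trend_lim = 0.01) {
  # Need at least n_min_batch checkpoints (and enough points for the window)
  if (length(means) < max(n_min_batch, n_slope_check)) return(FALSE)
  window <- tail(means, n_slope_check)
  k <- seq_along(window)
  slope <- cov(k, window) / var(k)               # least-squares slope
  rel_std <- sqrt(chain_var) / abs(tail(means, 1))
  rel_std <= std_lim && abs(slope) <= trend_lim
}

trend_std_check(c(5, 5, 5), chain_var = 0)   # flat and tight
trend_std_check(c(1, 2, 3), chain_var = 0)   # still trending
```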

max_R_hat

maximum allowed R_hat
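R-hat compares between-chain and within-chain variability of a parameter. The base-R sketch below is the classic (non-split) Gelman-Rubin statistic, assuming x is a matrix holding one parameter's trajectory with one column per chain; gelman_rubin is a hypothetical helper, and ngme2 may use a different variant internally.

```r
# Hypothetical helper: Gelman-Rubin R-hat for one parameter.
gelman_rubin <- function(x) {
  n <- nrow(x)                              # iterations per chain
  B <- n * var(colMeans(x))                 # between-chain variance
  W <- mean(apply(x, 2, var))               # mean within-chain variance
  var_hat <- (n - 1) / n * W + B / n        # pooled variance estimate
  sqrt(var_hat / W)
}

x <- matrix(rep(1:10, 4), ncol = 4)         # four identical chains
gelman_rubin(x) <= 1.1                      # passes with max_R_hat = 1.1
```

Chains that disagree (for example, the same columns shifted by 0, 100, 200 and 300) inflate B and push R-hat far above max_R_hat.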

pflug_conv_check

whether to use the Pflug diagnostic for the convergence check

pflug_alpha

scaling factor in (0, 1] for the Pflug criterion: requires pflug_sum < pflug_alpha * max_pflug_sum
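One common form of Pflug's diagnostic accumulates inner products of consecutive stochastic gradients: while the iterates still move in a consistent direction these products are positive, and once they oscillate around the optimum they turn negative and the running sum falls. The sketch below assumes pflug_sum is that running sum over one chain's latest batch and max_pflug_sum is its largest value; this interpretation and the helper pflug_check are hypothetical.

```r
# Hypothetical helper: Pflug-type check for one chain.
# grads: matrix with one row per iteration (gradient vectors), at least 2 rows.
pflug_check <- function(grads, pflug_alpha = 0.9) {
  # Inner products of consecutive gradients
  inner <- rowSums(grads[-1, , drop = FALSE] * grads[-nrow(grads), , drop = FALSE])
  pflug_path <- cumsum(inner)                # running Pflug sum
  pflug_sum <- tail(pflug_path, 1)
  max_pflug_sum <- max(pflug_path)
  pflug_sum < pflug_alpha * max_pflug_sum
}

oscillating <- rbind(c(1, 0), c(1, 0), c(1, 0), c(-1, 0), c(1, 0), c(-1, 0))
pflug_check(oscillating)                     # running sum 1, 2, 1, 0, -1
```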

R_hat_conv_check

whether to use the R-hat diagnostic (with threshold max_R_hat) for the convergence check

Value

list of control variables

Details

Convergence diagnostics (multi-chain):

* R-hat: per-parameter Gelman–Rubin statistic; passes if R_hat <= max_R_hat.
* Trend/Std: uses the last n_slope_check checkpoints after at least n_min_batch batches. Passes when both the relative std (sqrt(var)/|mean| <= std_lim) and the linear trend of the means (|slope| <= trend_lim) satisfy their thresholds.
* Pflug: per-chain criterion pflug_sum < pflug_alpha * max_pflug_sum in the latest batch; if all chains satisfy it, overall convergence is declared.

Checks are evaluated every iters_per_check = iterations / n_batch iterations. A parameter is marked converged if any enabled parameter-level diagnostic (R-hat or Trend/Std) passes; the run stops when all parameters converge or when the Pflug diagnostic triggers.
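The stopping rule above can be sketched as follows, assuming each diagnostic has already been reduced to logical vectors: one entry per parameter for R-hat and Trend/Std, one per chain for Pflug. The helper should_stop is hypothetical, not an ngme2 function.

```r
# Hypothetical helper: combine diagnostics into a stop decision.
should_stop <- function(rhat_pass, trend_pass, pflug_pass,
                        R_hat_conv_check = TRUE,
                        trend_std_conv_check = TRUE,
                        pflug_conv_check = TRUE) {
  # A parameter converges if any enabled parameter-level diagnostic passes
  param_ok <- (R_hat_conv_check & rhat_pass) |
              (trend_std_conv_check & trend_pass)
  # Stop when all parameters converge, or when Pflug holds on every chain
  all(param_ok) || (pflug_conv_check && all(pflug_pass))
}

should_stop(rhat_pass  = c(TRUE, FALSE),
            trend_pass = c(FALSE, TRUE),
            pflug_pass = c(FALSE, FALSE, FALSE, FALSE))
```

Here each parameter passes one of the two parameter-level diagnostics, so the run stops even though no chain satisfies the Pflug criterion.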