This function compiles the neural network described by the learner object and trains it with the input data.

# S3 method for ruta_autoencoder
train(
  learner,
  data,
  validation_data = NULL,
  metrics = NULL,
  epochs = 20,
  optimizer = "rmsprop",
  ...
)

train(learner, ...)

Arguments

learner

A "ruta_autoencoder" object

data

Training data: columns are attributes and rows are instances

validation_data

Additional numeric data matrix that is not used for training; the loss and any metrics are also computed against it

metrics

Optional list of metrics that are evaluated on the model but not optimized. See keras::compile

epochs

The number of times data will pass through the network

optimizer

The optimizer used to train the model. It can be an optimizer name (such as the default "rmsprop") or any optimizer object defined by Keras (e.g. keras::optimizer_adam())

...

Additional parameters for keras::fit. Some useful parameters:

  • batch_size The number of examples to be grouped for each gradient update. Use a smaller batch size for more frequent weight updates or a larger one for faster optimization.

  • shuffle Whether to shuffle the training data before each epoch (defaults to TRUE)

Value

The same autoencoder that was passed as learner, with its internal models trained

See also

autoencoder

Examples

# Minimal example ================================================
# \donttest{
if (keras::is_keras_available())
  iris_model <- train(autoencoder(2), as.matrix(iris[, 1:4]))
# }

# Simple example with MNIST ======================================
# \donttest{
library(keras)
if (keras::is_keras_available()) {
  # Load and normalize MNIST
  mnist <- dataset_mnist()
  x_train <- array_reshape(
    mnist$train$x, c(dim(mnist$train$x)[1], 784)
  )
  x_train <- x_train / 255.0
  x_test <- array_reshape(
    mnist$test$x, c(dim(mnist$test$x)[1], 784)
  )
  x_test <- x_test / 255.0

  # Autoencoder with layers: 784-256-36-256-784
  learner <- autoencoder(c(256, 36), "binary_crossentropy")
  learner <- train(
    learner,
    x_train,
    epochs = 1,
    optimizer = "rmsprop",
    batch_size = 64,
    validation_data = x_test,
    metrics = list("binary_accuracy")
  )
}
# }
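
One more sketch, under the assumption that both the ruta and keras packages are installed, may help show how validation_data, an optimizer object, and the extra arguments forwarded through ... fit together. The names iris_train, iris_val and trained, and the 120/30 row split, are illustrative choices and not part of the package.

```r
# Hold-out validation with an optimizer object ===================
# Illustrative split of iris: 120 rows for training, 30 for validation
x <- as.matrix(iris[, 1:4])
set.seed(42)
train_rows <- sample(nrow(x), 120)
iris_train <- x[train_rows, ]
iris_val <- x[-train_rows, ]

# Train only when ruta, keras and a working Keras backend are present
if (requireNamespace("keras", quietly = TRUE) &&
    requireNamespace("ruta", quietly = TRUE) &&
    keras::is_keras_available()) {
  trained <- ruta::train(
    ruta::autoencoder(2),
    iris_train,
    validation_data = iris_val,           # evaluated each epoch, never trained on
    epochs = 5,
    optimizer = keras::optimizer_adam(),  # optimizer object instead of a name
    batch_size = 16,                      # forwarded to the Keras fit call via ...
    shuffle = TRUE                        # forwarded to the Keras fit call via ...
  )
}
```

Assigning the result keeps the trained autoencoder, since train returns the learner rather than only printing progress.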