R/autoencoder.R, R/generics.R
train.ruta_autoencoder.Rd
This function compiles the neural network described by the learner object and trains it with the input data.
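Conceptually, training an autoencoder means fitting it to reproduce its own input. The data matrix is both the input and the target, and the loss measures the reconstruction error. ruta delegates this to Keras. The base-R sketch below only illustrates the idea, and every choice in it is an assumption rather than ruta's implementation: a linear 4-2-4 autoencoder on the scaled iris measurements, mean squared error loss, and plain full-batch gradient descent. A linear autoencoder with a 2-unit code cannot beat the best rank-2 approximation of the data, so its final error is compared against the PCA reconstruction error.

```r
# Illustrative sketch only: a linear 4-2-4 autoencoder trained with
# full-batch gradient descent on the mean squared reconstruction error.
# Architecture, loss, learning rate and epoch count are assumptions.
x <- scale(as.matrix(iris[, 1:4]))   # centred and scaled, 150 x 4
n <- nrow(x)
p <- ncol(x)

set.seed(1)
w_enc <- matrix(rnorm(p * 2, sd = 0.1), p, 2)  # encoder weights (4 -> 2)
w_dec <- matrix(rnorm(2 * p, sd = 0.1), 2, p)  # decoder weights (2 -> 4)

reconstruction_mse <- function(w_enc, w_dec) {
  mean((x %*% w_enc %*% w_dec - x)^2)
}

loss_before <- reconstruction_mse(w_enc, w_dec)
lr <- 0.05
for (epoch in 1:1000) {
  h <- x %*% w_enc                 # codes (encoder output)
  r <- h %*% w_dec - x             # reconstruction residual
  # gradients of mean(r^2) with respect to each weight matrix
  g_dec <- crossprod(h, r) * (2 / (n * p))
  g_enc <- crossprod(x, r %*% t(w_dec)) * (2 / (n * p))
  w_dec <- w_dec - lr * g_dec
  w_enc <- w_enc - lr * g_enc
}
loss_after <- reconstruction_mse(w_enc, w_dec)

# Best possible rank-2 linear reconstruction (PCA) as a lower bound
pc <- prcomp(x, center = FALSE)
pca_mse <- mean((pc$x[, 1:2] %*% t(pc$rotation[, 1:2]) - x)^2)

c(before = loss_before, after = loss_after, pca_bound = pca_mse)
```

Real autoencoders built with ruta add nonlinear activations and a Keras optimizer. The shared idea is that the network is trained to map the data onto itself through a narrower code layer.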
# S3 method for ruta_autoencoder
train(
learner,
data,
validation_data = NULL,
metrics = NULL,
epochs = 20,
optimizer = "rmsprop",
...
)
train(learner, ...)
learner
A "ruta_autoencoder" object
data
Training data: columns are attributes and rows are instances
validation_data
Additional numeric data matrix which will not be used for training, but against which the loss measure and any metrics will be computed
metrics
Optional list of metrics which will evaluate the model but won't be optimized. See keras::compile
epochs
The number of times the training data will pass through the network (default 20)
optimizer
The optimizer used to train the model: either the name of a Keras optimizer (the default is "rmsprop") or any optimizer object defined by Keras (e.g. keras::optimizer_adam())
...
Additional parameters for keras::fit. Some useful parameters:
batch_size
The number of examples to be grouped for each gradient update (Keras uses 32 if it is not given).
Use a smaller batch size for more frequent weight updates, or a larger one for fewer updates per epoch, which usually makes each epoch faster.
shuffle
Whether to shuffle the training data before each epoch, defaults to TRUE
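To make batch_size and shuffle concrete, consider what happens in each epoch. The rows of the training matrix are optionally reordered at random, then split into consecutive mini-batches of batch_size rows, where the last batch may be smaller. One epoch therefore performs ceiling(nrow(data) / batch_size) gradient updates. The helper `epoch_batches()` below is hypothetical and written in base R purely for illustration. It is not part of ruta or Keras, which handle batching internally.

```r
# Hypothetical helper, not part of ruta or keras: returns the row indices
# that make up each mini-batch of one training epoch.
epoch_batches <- function(n, batch_size = 32, shuffle = TRUE) {
  idx <- if (shuffle) sample.int(n) else seq_len(n)
  # consecutive groups of batch_size rows; the last group may be smaller
  split(idx, ceiling(seq_along(idx) / batch_size))
}

# 150 iris rows with batch_size = 64: three updates per epoch
batches <- epoch_batches(150, batch_size = 64, shuffle = FALSE)
length(batches)            # 3
unname(lengths(batches))   # 64 64 22

# The MNIST example below (60000 training rows, batch_size = 64)
length(epoch_batches(60000, batch_size = 64))  # 938
```

Whether or not the rows are shuffled, every row is used exactly once per epoch; shuffling only changes which rows end up sharing a batch.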
The same autoencoder passed as learner, with its internal models trained
autoencoder
# Minimal example ================================================
# \donttest{
if (keras::is_keras_available())
iris_model <- train(autoencoder(2), as.matrix(iris[, 1:4]))
# }
# Simple example with MNIST ======================================
# \donttest{
library(keras)
if (keras::is_keras_available()) {
# Load and normalize MNIST
mnist <- dataset_mnist()
x_train <- array_reshape(
mnist$train$x, c(dim(mnist$train$x)[1], 784)
)
x_train <- x_train / 255.0
x_test <- array_reshape(
mnist$test$x, c(dim(mnist$test$x)[1], 784)
)
x_test <- x_test / 255.0
# Autoencoder with layers: 784-256-36-256-784
learner <- autoencoder(c(256, 36), "binary_crossentropy")
train(
learner,
x_train,
epochs = 1,
optimizer = "rmsprop",
batch_size = 64,
validation_data = x_test,
metrics = list("binary_accuracy")
)
}
# }