class Trainer(ctx: Optional[mxnet.context.Context] = None, epochs: int = 100, batch_size: Optional[int] = None, num_batches_per_epoch: int = 50, learning_rate: float = 0.001, learning_rate_decay_factor: float = 0.5, patience: int = 10, minimum_learning_rate: float = 5e-05, clip_gradient: float = 10.0, weight_decay: float = 1e-08, init: Union[str, mxnet.initializer.Initializer] = 'xavier', hybridize: bool = True, avg_strategy: Union[…] = …(metric="score", num_models=1), post_initialize_cb: Optional[Callable[[mxnet.gluon.block.Block], None]] = None)

Bases: object

A trainer specifies how a network is going to be trained.

A trainer is mainly defined by two sets of parameters. The first one determines the number of examples that the network will be trained on (epochs, num_batches_per_epoch and batch_size), while the second one specifies how the gradient updates are performed (learning_rate, learning_rate_decay_factor, patience, minimum_learning_rate, clip_gradient and weight_decay).
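For the first group of parameters, the total number of examples seen during training is simply the product of the three size parameters, because an epoch here is a fixed number of batches (num_batches_per_epoch) rather than a full pass over the dataset. The following is a minimal sketch of that arithmetic, assuming every batch is full; the helper name `total_training_examples` is hypothetical and not part of the library.

```python
def total_training_examples(epochs: int, num_batches_per_epoch: int, batch_size: int) -> int:
    """Hypothetical helper: examples seen in training, assuming every batch is full."""
    # Each epoch draws a fixed number of batches, each holding batch_size examples.
    examples_per_epoch = num_batches_per_epoch * batch_size
    return epochs * examples_per_epoch


# epochs and num_batches_per_epoch use the documented defaults; batch_size=32 is an assumed value.
print(total_training_examples(epochs=100, num_batches_per_epoch=50, batch_size=32))  # 160000
```

Doubling num_batches_per_epoch therefore has the same effect on the training budget as doubling epochs, although it changes how often the learning-rate schedule is evaluated.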

  • ctx – MXNet context (device, e.g. CPU or GPU) on which the network is trained (default: None).

  • epochs – Number of epochs that the network will train (default: 100).

  • batch_size – Number of examples in each batch (default: None).

  • num_batches_per_epoch – Number of batches at each epoch (default: 50).

  • learning_rate – Initial learning rate (default: \(10^{-3}\)).

  • learning_rate_decay_factor – Factor (between 0 and 1) by which to decrease the learning rate (default: 0.5).

  • patience – Number of epochs without improvement in the monitored loss to wait before reducing the learning rate; a nonnegative integer (default: 10).

  • minimum_learning_rate – Lower bound for the learning rate (default: \(5\cdot 10^{-5}\)).

  • clip_gradient – Gradient clipping threshold: gradient values are clipped to the range [-clip_gradient, clip_gradient] (default: 10).

  • weight_decay – The weight decay (or L2 regularization) coefficient. Modifies the objective by adding a penalty for large weights (default: \(10^{-8}\)).

  • init – Initializer of the weights of the network (default: “xavier”).

  • hybridize – If set to True, the network is hybridized before training (default: True).

  • post_initialize_cb – An optional callback function. If provided, it is called as post_initialize_cb(net) with the initialized network before training starts. It can be used, e.g., to overwrite parameters for warm starting or to freeze some of the network parameters.
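The second group of parameters (learning_rate, learning_rate_decay_factor, patience, minimum_learning_rate) describes a reduce-on-plateau schedule. The sketch below is a minimal, hypothetical illustration of how these four values can interact, assuming the loss is observed once per epoch and the rate is multiplied by the decay factor after `patience` consecutive epochs without a new best loss, never dropping below the minimum. The class `PlateauLRSchedule` is not part of the library, and the actual trainer may differ in details such as how improvement is measured.

```python
from dataclasses import dataclass, field


@dataclass
class PlateauLRSchedule:
    """Hypothetical reduce-on-plateau schedule using the documented parameter names."""

    learning_rate: float = 1e-3
    learning_rate_decay_factor: float = 0.5
    patience: int = 10
    minimum_learning_rate: float = 5e-5
    # Internal state: best loss so far and how many epochs have passed without beating it.
    best_loss: float = field(default=float("inf"), init=False)
    epochs_without_improvement: int = field(default=0, init=False)

    def step(self, epoch_loss: float) -> float:
        """Record one epoch's loss and return the learning rate for the next epoch."""
        if epoch_loss < self.best_loss:
            # New best loss: remember it and reset the patience counter.
            self.best_loss = epoch_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
            if self.epochs_without_improvement >= self.patience:
                # Patience exhausted: decay the rate, clamped at the lower bound.
                self.learning_rate = max(
                    self.learning_rate * self.learning_rate_decay_factor,
                    self.minimum_learning_rate,
                )
                self.epochs_without_improvement = 0
        return self.learning_rate


schedule = PlateauLRSchedule(patience=2)
for loss in [1.0, 1.0, 1.0]:
    lr = schedule.step(loss)
print(lr)  # 0.0005
```

With the defaults, the rate can be halved at most six times (0.001 down to about 1.6e-5 would undershoot), so after enough stagnant epochs it settles at minimum_learning_rate.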

count_model_params(net: mxnet.gluon.block.HybridBlock) → int
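Assuming count_model_params returns the total number of scalar values across all of the network's parameter tensors, the count reduces to summing the product of each tensor's shape. The sketch below works on a plain mapping from parameter names to shapes so that it runs without MXNet; with a Gluon network, such shapes could be read from `net.collect_params()`. The function `count_params_from_shapes` and the example shapes are hypothetical.

```python
import math
from typing import Dict, Tuple


def count_params_from_shapes(param_shapes: Dict[str, Tuple[int, ...]]) -> int:
    """Hypothetical stand-in: total number of scalar values across all parameter tensors."""
    # Each tensor contributes the product of its dimensions; a scalar has shape () and counts as 1.
    return sum(math.prod(shape) for shape in param_shapes.values())


# A dense layer mapping 10 inputs to 5 outputs: a 5x10 weight matrix plus a bias of length 5.
print(count_params_from_shapes({"dense0_weight": (5, 10), "dense0_bias": (5,)}))  # 55
```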