class StepDropout(step_size, base_drop_rate, gamma=0.0, update_interval='epoch', log=True, log_name='drop_rate', ascending=False, **kwargs)

Step Dropout.

A simple dropout scheduler: every step_size epochs (or steps), the drop rate is decayed by the multiplicative factor gamma.

Examples:

>>> from pytorch_lightning import Trainer
>>> # Early Dropout (drop rate from .1 to 0 after 50 epochs)
>>> trainer = Trainer(callbacks=[StepDropout(50, base_drop_rate=.1, gamma=0.)])
>>> # Late Dropout (drop rate from 0 to .1 after 50 epochs)
>>> trainer = Trainer(callbacks=[StepDropout(50, base_drop_rate=.1, gamma=0., ascending=True)])

Parameters:

  • step_size – Period of drop rate decay, counted in units of update_interval (epochs or steps).

  • base_drop_rate – Base drop rate: the starting rate, or the rate approached when ascending=True.

  • gamma – Multiplicative factor of drop rate decay. Default: 0., which replicates “Early Dropout” (the drop rate falls to 0 after the first step_size period).

  • update_interval – Whether the drop rate is updated every training step (batch) or every epoch. One of ('step', 'epoch').

  • log – Whether to log the drop rate using module.log(log_name, drop_rate).

  • log_name – Name under which the drop rate is logged.

  • logger – If True, logs to the logger.

  • ascending – If True, the schedule is mirrored: the drop rate starts at 0 and ascends towards base_drop_rate. With ascending=True and gamma=0., this replicates “Late Dropout”.

  • **kwargs – Additional keyword arguments passed to module.log.
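The interplay of step_size, gamma and ascending can be sketched as a StepLR-style schedule. This is an assumption: the descending rate is taken to be base_drop_rate * gamma ** (step // step_size) and the ascending rate its mirror image, base_drop_rate * (1 - gamma ** (step // step_size)). The function step_drop_rate is a hypothetical stand-in, not the library's get_rate, whose exact formula may differ.

```python
def step_drop_rate(base: float, gamma: float, step: int, step_size: int,
                   ascending: bool = False) -> float:
    """Hypothetical StepLR-style drop-rate schedule (assumed, not the library's get_rate).

    Descending: base * gamma ** (step // step_size)
    Ascending:  base * (1 - gamma ** (step // step_size)), the mirror image,
                starting at 0 and rising towards base.
    """
    # 0.0 ** 0 == 1.0, so the first period uses `base` (descending) or 0 (ascending)
    decay = gamma ** (step // step_size)
    return base * (1.0 - decay) if ascending else base * decay


# Early vs. Late Dropout as in the examples (step_size=50, gamma=0.)
for epoch in (0, 49, 50, 100):
    early = step_drop_rate(0.1, 0.0, epoch, 50)
    late = step_drop_rate(0.1, 0.0, epoch, 50, ascending=True)
    print(epoch, early, late)
```

With gamma=0. both schedules are a single switch at epoch 50; a gamma between 0 and 1 instead yields a gradual geometric decay (or, mirrored, a gradual rise towards base_drop_rate).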

static get_rate(base, gamma, step, step_size, ascending)
on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) → None
on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) → None
update_drop_rate(pl_module: LightningModule, drop_rate: float)
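update_drop_rate applies a new rate to the model. A minimal sketch of that idea, assuming the rate is written into the p attribute of every PyTorch dropout layer in the module; set_drop_rate is a hypothetical helper for illustration, not the library's implementation.

```python
from torch import nn


def set_drop_rate(module: nn.Module, drop_rate: float) -> int:
    """Set the drop probability of every dropout layer in ``module``.

    Hypothetical helper illustrating what update_drop_rate might do;
    returns the number of layers that were changed.
    """
    count = 0
    for m in module.modules():
        # _DropoutNd is torch's (private) common base class of its dropout layers,
        # e.g. Dropout, Dropout2d and AlphaDropout; each stores its rate in ``p``.
        if isinstance(m, nn.modules.dropout._DropoutNd):
            m.p = float(drop_rate)
            count += 1
    return count


model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1), nn.ReLU(), nn.Dropout(0.1))
set_drop_rate(model, 0.0)  # e.g. switching dropout off, as Early Dropout does after step_size epochs
```

Because nn.Dropout reads self.p on every forward pass, a changed rate takes effect from the next batch onward without rebuilding the model.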