# Source code for pyro.optim.lr_scheduler

from __future__ import absolute_import, division, print_function

from pyro.optim.optim import PyroOptim


class PyroLRScheduler(PyroOptim):
    """
    A wrapper for ``torch.optim.lr_scheduler`` objects that adjusts learning
    rates for dynamically generated parameters.

    :param scheduler_constructor: a ``torch.optim.lr_scheduler`` class; it is
        instantiated once per optimizer built by :meth:`_get_optim`.
    :param optim_args: a dictionary of arguments. It must contain the key
        ``'optimizer'`` mapped to a pytorch optimizer class and the key
        ``'optim_args'`` mapped to the keyword arguments for that optimizer.
        Every remaining entry is forwarded as a keyword argument to
        ``scheduler_constructor``. The dictionary is copied, so the caller's
        object is not modified and may be reused.
    :raises KeyError: if ``'optimizer'`` or ``'optim_args'`` is missing.

    Example::

        optimizer = torch.optim.SGD
        pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
        svi = SVI(model, guide, pyro_scheduler, loss=TraceGraph_ELBO())
        svi.step()
    """

    def __init__(self, scheduler_constructor, optim_args):
        # Pop from a shallow copy: mutating the caller's dict would make it
        # unusable for building a second scheduler (KeyError on reuse).
        optim_args = dict(optim_args)
        # pytorch scheduler class
        self.pt_scheduler_constructor = scheduler_constructor
        # torch optimizer class
        pt_optim_constructor = optim_args.pop('optimizer')
        # kwargs for the torch optimizer
        optim_kwargs = optim_args.pop('optim_args')
        # whatever remains is passed to the scheduler constructor
        self.kwargs = optim_args
        # current epoch; injected into every __call__ and updated via set_epoch
        self.epoch = None
        super(PyroLRScheduler, self).__init__(pt_optim_constructor, optim_kwargs)

    def __call__(self, params, *args, **kwargs):
        """
        Delegate to ``PyroOptim.__call__`` with ``epoch=self.epoch`` added to
        the keyword arguments.

        NOTE(review): presumably ``PyroOptim`` forwards these kwargs to each
        per-parameter scheduler's ``step``; confirm against ``PyroOptim``.

        :param params: the parameters to step
        """
        # Deliberately overwrites any ``epoch`` passed by the caller; it is
        # None until set_epoch has been called.
        kwargs['epoch'] = self.epoch
        super(PyroLRScheduler, self).__call__(params, *args, **kwargs)

    def _get_optim(self, params):
        """
        Build the torch optimizer for ``params`` via ``PyroOptim`` and wrap it
        in a scheduler created with the remaining constructor kwargs.

        :returns: the scheduler instance wrapping the optimizer
        """
        optim = super(PyroLRScheduler, self)._get_optim(params)
        return self.pt_scheduler_constructor(optim, **self.kwargs)

    def set_epoch(self, epoch):
        """
        Store the epoch value passed as ``epoch=`` on subsequent calls.

        :param epoch: the epoch value to forward
        """
        self.epoch = epoch