import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR, ExponentialLR

import util


class Trainer():
    def __init__(self, model, lrate, wdecay, clip, step_size, seq_out_len, scaler, device, cl=True):
        self.scaler = scaler
        self.model = model
        self.model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay)
        self.loss = util.masked_mae
        self.clip = clip
        self.step = step_size
        self.iter = 1
        self.task_level = 1
        self.seq_out_len = seq_out_len
        self.cl = cl

    def train(self, input, real_val, idx=None):
        self.model.train()
        self.optimizer.zero_grad()
        output = self.model(input, idx=idx)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        predict = self.scaler.inverse_transform(output)

        # Curriculum learning: every `step_size` iterations, widen the
        # supervised horizon by one step until it spans seq_out_len.
        if self.iter % self.step == 0 and self.task_level <= self.seq_out_len:
            self.task_level += 1
        if self.cl:
            loss = self.loss(predict[:, :, :, :self.task_level],
                             real[:, :, :, :self.task_level], 0.0)
        else:
            loss = self.loss(predict, real, 0.0)

        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        # mae = util.masked_mae(predict, real, 0.0).item()
        mape = util.masked_mape(predict, real, 0.0).item()
        rmse = util.masked_rmse(predict, real, 0.0).item()
        self.iter += 1
        return loss.item(), mape, rmse

    def eval(self, input, real_val):
        self.model.eval()
        with torch.no_grad():
            output = self.model(input)
            output = output.transpose(1, 3)
            real = torch.unsqueeze(real_val, dim=1)
            predict = self.scaler.inverse_transform(output)
            loss = self.loss(predict, real, 0.0)
            mape = util.masked_mape(predict, real, 0.0).item()
            rmse = util.masked_rmse(predict, real, 0.0).item()
        return loss.item(), mape, rmse
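# A minimal usage sketch for Trainer (illustrative only). DummyModel and
# DummyScaler below are hypothetical stand-ins for the real graph model and
# the util.StandardScaler-style object; the shapes assume the model maps
# (batch, in_dim, num_nodes, seq_in_len) to (batch, seq_out_len, num_nodes, 1),
# which transpose(1, 3) then aligns with real_val.unsqueeze(1).
def _demo_trainer():
    import torch.nn as nn

    class DummyScaler:  # hypothetical: identity "inverse" normalization
        def inverse_transform(self, x):
            return x

    class DummyModel(nn.Module):  # hypothetical one-layer predictor
        def __init__(self, in_dim, seq_in_len, seq_out_len):
            super().__init__()
            self.proj = nn.Conv2d(in_dim, seq_out_len, kernel_size=(1, seq_in_len))

        def forward(self, input, idx=None):
            # (batch, in_dim, num_nodes, seq_in_len) -> (batch, seq_out_len, num_nodes, 1)
            return self.proj(input)

    batch, in_dim, num_nodes, seq_in, seq_out = 4, 2, 7, 12, 12
    trainer = Trainer(DummyModel(in_dim, seq_in, seq_out), lrate=1e-3, wdecay=1e-4,
                      clip=5, step_size=100, seq_out_len=seq_out,
                      scaler=DummyScaler(), device='cpu')
    x = torch.randn(batch, in_dim, num_nodes, seq_in)
    # Keep labels nonzero so the masked metrics (null_val=0.0) see no masked entries.
    y = torch.abs(torch.randn(batch, num_nodes, seq_out)) + 1.0
    print(trainer.train(x, y))  # (loss, mape, rmse) for one step
    print(trainer.eval(x, y))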
class Optim(object):

    def _makeOptimizer(self):
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'Nadam':
            self.optimizer = optim.NAdam(self.params, lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        elif self.method == 'adamw':
            self.optimizer = optim.AdamW(self.params, lr=self.lr, weight_decay=self.lr_decay)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, params, method, lr, clip, mode, factor, patience, steps_per_epoch, epochs,
                 lr_decay=1, start_decay_at=None):
        # params may be a generator: materialize it so it survives repeated
        # iteration (optimizer construction, clipping, optimizer rebuilds).
        self.params = list(params)
        self.last_ppl = None
        self.lr = lr
        self.clip = clip
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False

        self._makeOptimizer()
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode=mode, factor=factor,
                                           patience=patience, verbose=True)
        # self.scheduler = ExponentialLR(self.optimizer, gamma=0.9, verbose=True)
        # self.scheduler = OneCycleLR(self.optimizer, max_lr=0.005, steps_per_epoch=steps_per_epoch,
        #                             epochs=epochs, div_factor=1.1, final_div_factor=10)

    def step(self):
        # Clip gradients (if requested) and report their total norm.
        grad_norm = 0
        if self.clip is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.clip)
        self.optimizer.step()
        return grad_norm

    def lronplateau(self, loss):
        self.scheduler.step(loss)

    def EXlr(self):
        self.scheduler.step()

    def CosineAnnealingLR(self):
        self.scheduler.step()

    # Decay the learning rate if validation perf does not improve,
    # or once we hit the start_decay_at epoch.
    def updateLearningRate(self, ppl, epoch):
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)
            # only decay for one epoch
            self.start_decay = False

        self.last_ppl = ppl
        self._makeOptimizer()
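# A minimal usage sketch for Optim (illustrative only). The toy linear model
# and all hyperparameter values below are hypothetical; mode/factor/patience
# are forwarded to the ReduceLROnPlateau scheduler built in __init__, and
# updateLearningRate is the separate manual-decay path not exercised here.
def _demo_optim():
    import torch.nn as nn

    model = nn.Linear(8, 1)
    opt = Optim(model.parameters(), method='adam', lr=1e-3, clip=5,
                mode='min', factor=0.5, patience=3,
                steps_per_epoch=10, epochs=2, lr_decay=0.9, start_decay_at=1)

    x, y = torch.randn(16, 8), torch.randn(16, 1)
    for epoch in range(2):
        opt.optimizer.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        grad_norm = opt.step()        # clips, steps, returns the gradient norm
        opt.lronplateau(loss.item())  # plateau scheduling on the epoch loss
    print(float(grad_norm), opt.lr)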