Problem formulation (Easeformer Core)
$$ \mathrm{head}_i = \mathrm{Softmax}\left(\frac{(h^t)^T W^Q_i \left((h^t)^T W^K_i\right)^T}{\sqrt{d}}\right)(h^t)^T W^V_i $$
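Here $h^t$ is the stack of hidden states at step $t$, $W^Q_i$, $W^K_i$, $W^V_i$ are the query, key, and value projections of head $i$, and $d$ is the per-head dimension. A minimal sketch of this computation in PyTorch, with illustrative tensor sizes rather than the model's actual configuration:

import torch
import torch.nn.functional as F

# h_t: hidden states (batch, seq_len, d_model); d = per-head dimension.
# W_q, W_k, W_v stand in for the per-head projections W^Q_i, W^K_i, W^V_i.
batch, seq_len, d_model, d = 32, 96, 512, 64
h_t = torch.randn(batch, seq_len, d_model)
W_q = torch.randn(d_model, d)
W_k = torch.randn(d_model, d)
W_v = torch.randn(d_model, d)

Q = h_t @ W_q                                # (batch, seq_len, d)
K = h_t @ W_k                                # (batch, seq_len, d)
V = h_t @ W_v                                # (batch, seq_len, d)
scores = Q @ K.transpose(-2, -1) / d ** 0.5  # (batch, seq_len, seq_len)
head_i = F.softmax(scores, dim=-1) @ V       # (batch, seq_len, d)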
# decoder input: keep the last label_len known steps as a start token and pad
# the pred_len horizon; the first channel of the horizon is retained, the
# remaining channels are filled with zeros or ones according to args.padding
if self.args.padding == 0:
    dec_inp_pad = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1] - 1]).float()
elif self.args.padding == 1:
    dec_inp_pad = torch.ones([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1] - 1]).float()
dec_inp = torch.cat((batch_y[:, self.args.label_len:self.args.label_len + self.args.pred_len, :1], dec_inp_pad), dim=2)
dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
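The decoder input therefore has length label_len + pred_len: the known history serves as the start token, and the horizon is a placeholder whose first channel is kept from batch_y (presumably a covariate known in advance) while the remaining channels carry the padding value. A quick shape check, using illustrative sizes rather than the paper's settings:

import torch

batch, label_len, pred_len, n_feats = 32, 48, 24, 7
batch_y = torch.randn(batch, label_len + pred_len, n_feats)

placeholder = torch.zeros(batch, pred_len, n_feats - 1)           # padded channels
known_first = batch_y[:, label_len:label_len + pred_len, :1]      # retained first channel
dec_inp = torch.cat((known_first, placeholder), dim=2)            # (batch, pred_len, n_feats)
dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1)  # start token + horizon
assert dec_inp.shape == (batch, label_len + pred_len, n_feats)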
# encoder - decoder forward pass; with output_attention the model returns
# (predictions, attention weights), so [0] keeps only the predictions
if self.args.use_amp:
    with torch.cuda.amp.autocast():
        if self.args.output_attention:
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
        else:
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
else:
    if self.args.output_attention:
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
    else:
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
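When use_amp is set, only the forward pass above runs under autocast; a standard PyTorch AMP loop pairs it with a GradScaler for the backward pass, which this section does not show. A self-contained sketch of that pairing, using stand-in model, optimizer, and criterion objects rather than the experiment's own:

import torch

# Stand-ins for illustration only; a CUDA device is required for AMP.
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(32, 8, device="cuda")
y = torch.randn(32, 1, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():    # forward in mixed precision
    pred = model(x)
    loss = criterion(pred, y)
scaler.scale(loss).backward()      # scale loss to avoid fp16 gradient underflow
scaler.step(optimizer)             # unscales gradients, then steps the optimizer
scaler.update()                    # adjust the scale factor for the next iteration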