How to Design a Transformer Model for Time-Series Forecasting
In a previous blog post, we explored the key aspects and benefits of transformer models, described how you can use pretrained models with MATLAB, and promised a blog post that shows you how to design transformers from scratch using built-in deep learning layers. In this blog post, I am going to provide you with the code you need to design a transformer model for time-series forecasting.
The originally proposed architecture for transformers (Figure 1) includes encoder and decoder blocks. Since then, encoder-only transformers (like the BERT model) and decoder-only transformers (like GPT models) have been implemented. In this post, I will show you how to design a transformer model for time-series forecasting using only decoder blocks.
Figure 1: The original encoder-decoder architecture of a transformer model (adapted from Vaswani et al., 2017)
Decoder-Only Transformer Architecture
The architecture of the transformer model we are designing is shown in Figure 2. The model includes two decoder blocks that use masked multi-head attention. In decoder-only transformers, masked self-attention is used to ensure that the model can only access previous tokens in the input sequence. In encoder-only transformers, by contrast, the self-attention mechanism attends to all tokens in the input sequence. By applying a mask over future positions in the input sequence, the model preserves the causality constraint necessary for tasks like text generation and time-series forecasting, where each output token must be generated in a left-to-right manner. Without masked self-attention, the model could access information from future tokens, which would violate the sequential nature of generation and introduce unintended data leakage into the forecast. The short sketch following Figure 2 makes this masking pattern concrete.

Figure 2: Architecture of the decoder-only transformer model that we are designing
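To see why masking matters, here is a minimal numeric sketch of a causal attention mask. It is an illustration only (the scores are random toy values, not part of the model we are building); the point is the pattern of zeros in the resulting attention weights.

% Toy attention scores for a 4-step sequence (rows = queries, columns = keys)
seqLength = 4;
scores = randn(seqLength);

% The causal mask sets every score above the diagonal (future steps) to -Inf,
% so those positions receive zero attention weight after the softmax.
scores(triu(true(seqLength),1)) = -Inf;
weights = exp(scores)./sum(exp(scores),2)

Each row i of weights is nonzero only for columns 1 through i, so time step i never attends to later steps. The selfAttentionLayer used below applies this pattern when you set AttentionMask="causal".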
Decoder-Only Transformer Design

Here, I am going to provide you with MATLAB code to design, train, and analyze a decoder-only transformer architecture.

Define the layers of the transformer network.

% Model hyperparameters
numFeatures = 1;              % number of channels in the input time series
numHeads = 4;                 % attention heads per self-attention layer
numKeyChannels = 256;         % number of key and query channels
feedforwardHiddenSize = 512;  % hidden size of the position-wise feedforward layers
modelHiddenSize = 256;        % model (embedding) dimension
maxSequenceLength = 120;      % longest sequence the position embedding supports

decoderLayers = [
    % Input embedding plus learned position embedding
    sequenceInputLayer(numFeatures,Name="in")
    fullyConnectedLayer(modelHiddenSize,Name="embedding")
    positionEmbeddingLayer(modelHiddenSize,maxSequenceLength,Name="position_embed")
    additionLayer(2,Name="embed_add")
    layerNormalizationLayer(Name="embed_norm")

    % Decoder block 1: masked self-attention and feedforward sublayers,
    % each followed by a residual addition and layer normalization
    selfAttentionLayer(numHeads,numKeyChannels,AttentionMask="causal")
    additionLayer(2,Name="attention_add")
    layerNormalizationLayer(Name="attention_norm")
    fullyConnectedLayer(feedforwardHiddenSize)
    geluLayer
    fullyConnectedLayer(modelHiddenSize)
    additionLayer(2,Name="feedforward_add")
    layerNormalizationLayer(Name="decoder1_norm")

    % Decoder block 2: same structure as block 1
    selfAttentionLayer(numHeads,numKeyChannels,AttentionMask="causal")
    additionLayer(2,Name="attention2_add")
    layerNormalizationLayer(Name="attention2_norm")
    fullyConnectedLayer(feedforwardHiddenSize)
    geluLayer
    fullyConnectedLayer(modelHiddenSize)
    additionLayer(2,Name="feedforward2_add")
    layerNormalizationLayer(Name="decoder2_norm")

    % Prediction head
    fullyConnectedLayer(numFeatures,Name="head")];
Convert the layer array to a dlnetwork object. Specify Initialize=false because the second inputs of the addition layers are not connected yet.

net = dlnetwork(decoderLayers,Initialize=false);

Connect the layers in the network. The first connection adds the input embedding to the position embedding; the rest complete the residual (skip) connections around the attention and feedforward sublayers.
net = connectLayers(net,"embedding","embed_add/in2");             % embedding + position embedding
net = connectLayers(net,"embed_norm","attention_add/in2");        % residual around attention 1
net = connectLayers(net,"attention_norm","feedforward_add/in2");  % residual around feedforward 1
net = connectLayers(net,"decoder1_norm","attention2_add/in2");    % residual around attention 2
net = connectLayers(net,"attention2_norm","feedforward2_add/in2");% residual around feedforward 2

Initialize the learnable and state parameters of the network.
net = initialize(net);

Visualize and understand the architecture of the transformer network.
analyzeNetwork(net)
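To make the model useful, you also need to train it and generate forecasts. The following is a minimal sketch of one-step-ahead training with trainnet; data, XTrain, TTrain, and the training options are illustrative assumptions, so adapt them to your dataset.

% Assumes `data` is a cell array of numFeatures-by-numTimeSteps sequences,
% each no longer than maxSequenceLength (a limit set by the position embedding).
% Inputs drop the last step of each sequence; targets are shifted one step ahead.
XTrain = cellfun(@(s) s(:,1:end-1),data,UniformOutput=false);
TTrain = cellfun(@(s) s(:,2:end),data,UniformOutput=false);

options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    InitialLearnRate=1e-3, ...
    Shuffle="every-epoch", ...
    Plots="training-progress");

% Train with mean squared error loss, matching the regression head above.
net = trainnet(XTrain,TTrain,net,"mse",options);

Once trained, you can forecast beyond the observed data in closed loop by repeatedly feeding the model its own predictions. Here, XTest and numPredictionSteps are again assumed names:

% Closed-loop forecasting: append each new prediction to the input window.
numPredictionSteps = 10;
history = XTest;  % assumed numFeatures-by-T test sequence, with T <= maxSequenceLength
for k = 1:numPredictionSteps
    Y = predict(net,dlarray(history,"CTB"));      % one-step-ahead predictions
    history = [history extractdata(Y(:,:,end))];  % keep only the newest prediction
    history = history(:,max(1,end-maxSequenceLength+1):end); % respect the length limit
end
forecast = history(:,end-numPredictionSteps+1:end);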
Conclusion
Many pretrained transformer models exist for natural language processing and computer vision tasks. In fact, such pretrained models are available for you in MATLAB (see BERT and ViT). However, time-series forecasting is a newer application for transformers, and few pretrained models are available for it. Take advantage of the code provided in this post to build your own transformer model for time-series forecasting or adapt it for your task, and comment below to share your results.