# TimesNet-Gen / TimesNet.py
# Source: Hugging Face upload by Barisylmz ("Upload 4 files", commit 0dfdc08, verified; 17.9 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import numpy as np
# Simple embedding and conv blocks, defined here instead of importing them from a separate layers/ package.
class DataEmbedding(nn.Module):
    """Linear value embedding plus positional embedding.

    A simplified stand-in for the Time-Series-Library ``DataEmbedding``:
    ``embed_type`` and ``freq`` are stored but not used, and ``x_mark`` is
    ignored (no temporal/calendar embedding).

    Args:
        c_in: number of input channels (size of the last dim of ``x``).
        d_model: embedding width.
        embed_type: stored only; unused.
        freq: stored only; unused.
        dropout: dropout probability applied to the final embedding.
        seq_len: number of learned positional vectors. Inputs longer than
            this get sinusoidal encodings for the positions past ``seq_len``.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, seq_len=6000):
        super(DataEmbedding, self).__init__()
        self.c_in = c_in
        self.d_model = d_model
        self.embed_type = embed_type
        self.freq = freq
        self.seq_len = seq_len
        # Simple linear projection of each time step's channels.
        self.value_embedding = nn.Linear(c_in, d_model)
        # One learned positional vector per time step, up to seq_len steps.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len, d_model))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x`` of shape (B, T, c_in) into (B, T, d_model).

        For ``T <= seq_len`` the first ``T`` learned positional vectors are
        added. For ``T > seq_len`` all learned vectors are used, followed by
        sinusoidal encodings for the remaining ``T - seq_len`` steps.
        ``x_mark`` is accepted for interface compatibility and ignored.
        """
        x = self.value_embedding(x)
        length = x.size(1)
        learned = self.position_embedding.size(1)
        if length <= learned:
            pos = self.position_embedding[:, :length, :]
        else:
            # NOTE(review): the extra positions are encoded from index 0, not
            # from seq_len (this matches what the previous code attempted);
            # confirm this is intended.
            extra = self._get_sinusoidal_encoding(length - learned, self.d_model)
            extra = extra.unsqueeze(0).to(device=x.device, dtype=self.position_embedding.dtype)
            pos = torch.cat([self.position_embedding, extra], dim=1)
        return self.dropout(x + pos)

    def _get_sinusoidal_encoding(self, length, d_model):
        """Return a (length, d_model) sinusoidal position encoding.

        Even columns hold sin, odd columns hold cos, using the standard
        10000^(2i/d_model) frequency schedule. Odd ``d_model`` is supported.
        """
        position = torch.arange(length).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pos_encoding = torch.zeros(length, d_model)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies.
        pos_encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pos_encoding
class Inception_Block_V1(nn.Module):
    """Parallel multi-kernel 2D convolution ("inception"-style) block.

    Holds ``num_kernels`` ``nn.Conv2d`` layers whose square kernels have
    sizes 1, 3, 5, ... with matching padding, so every branch preserves the
    spatial size. ``forward`` runs all branches on the same input and returns
    their element-wise mean.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        # Branch j: kernel 2*j + 1 with padding j keeps H and W unchanged.
        self.kernels = nn.ModuleList(
            [
                nn.Conv2d(in_channels, out_channels, kernel_size=2 * j + 1, padding=j)
                for j in range(self.num_kernels)
            ]
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal weights (fan_out, ReLU gain) and zero biases for each conv."""
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Average the outputs of all convolution branches applied to ``x``."""
        branch_outputs = [conv(x) for conv in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)
def FFT_for_Period(x, k=2):
    """Find the ``k`` dominant periods of ``x`` from its FFT amplitude spectrum.

    Args:
        x: tensor of shape (B, T, C); the FFT is taken along dim 1.
        k: number of periods to return.

    Returns:
        period: numpy int array of length ``k``, ``T // frequency_index`` for
            the top-k frequency bins (ranked by amplitude averaged over batch
            and channels, with the DC bin forced to zero). The divisor is
            clamped to >= 1, so a selected DC bin yields period ``T`` instead
            of a zero period.
        weights: tensor of shape (B, k), the per-sample amplitude (averaged
            over channels) at each selected bin.
    """
    xf = torch.fft.rfft(x, dim=1)
    amplitude = xf.abs()
    # find period by amplitudes
    frequency_list = amplitude.mean(0).mean(-1)
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    # Index 0 can still be chosen when fewer than k bins have non-zero
    # amplitude; T // 0 would give a zero period downstream.
    period = x.shape[1] // np.maximum(top_list, 1)
    return period, amplitude.mean(-1)[:, top_list]
class TimesBlock(nn.Module):
    """TimesNet block: fold the sequence into 2D by its dominant periods and convolve.

    For each of the top-``k`` FFT periods the input is zero-padded to a
    multiple of the period, reshaped to (B, N, rows, period), passed through
    two Inception blocks, unfolded and truncated back to the input length.
    The per-period outputs are combined with softmax weights from the FFT
    amplitudes and added to the input (residual connection).
    """

    def __init__(self, configs):
        super(TimesBlock, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        # parameter-efficient design
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels),
        )

    def forward(self, x):
        """Apply the period-folded 2D convolution to ``x``.

        Args:
            x: tensor of shape (B, T, N), where N must equal ``configs.d_model``
                (the Inception blocks' channel count). Any T is accepted; it
                is no longer required to equal ``seq_len + pred_len``.

        Returns:
            Tensor of shape (B, T, N).
        """
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)
        res = []
        for i in range(self.k):
            # Guard against a zero period (DC bin ranked among the top-k).
            period = max(int(period_list[i]), 1)
            # Pad on the right so the length is a multiple of the period.
            if T % period != 0:
                length = (T // period + 1) * period
                padding = x.new_zeros((B, length - T, N))
                out = torch.cat([x, padding], dim=1)
            else:
                length = T
                out = x
            # (B, T', N) -> (B, N, rows, period): 1D variation to 2D variation
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()
            out = self.conv(out)
            # Back to (B, T', N), then drop the padding.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :T, :])
        res = torch.stack(res, dim=-1)  # (B, T, N, k)
        # adaptive aggregation: softmax over the k periods, broadcast over T and N
        period_weight = F.softmax(period_weight, dim=1)
        res = torch.sum(res * period_weight[:, None, None, :], dim=-1)
        # residual connection
        return res + x
class Model(nn.Module):
    """TimesNet for forecasting, imputation, anomaly detection and classification.

    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq

    ``configs.task_name`` selects the output head and the path taken by
    ``forward``. When ``configs.use_ps_heads`` is truthy, extra P/S heads are
    built and ``anomaly_detection`` returns their predictions instead of a
    reconstruction (transfer-learning mode).
    """

    # Output lengths of the multi-scale pools feeding the P/S heads; the
    # fusion layer input width is d_model * sum(_PS_POOL_SIZES).
    _PS_POOL_SIZES = (16, 4, 1)
    # From this sequence length on, pooling is done by reshape + mean instead
    # of nn.AdaptiveAvgPool1d (the original code cites CUDA memory issues).
    _PS_MANUAL_POOL_MIN_LEN = 8000

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.model = nn.ModuleList([TimesBlock(configs)
                                    for _ in range(configs.e_layers)])
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout, configs.seq_len)
        self.layer = configs.e_layers
        self.layer_norm = nn.LayerNorm(configs.d_model)
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            self.predict_linear = nn.Linear(self.seq_len, self.pred_len + self.seq_len)
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name in ('imputation', 'anomaly_detection'):
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        # Transfer-learning P/S heads, built only when requested. They are
        # created independently of task_name, but only anomaly_detection()
        # uses them.
        if getattr(configs, 'use_ps_heads', False):
            self._build_ps_heads(configs)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

    # ------------------------------------------------------------------
    # P/S head construction
    # ------------------------------------------------------------------
    def _build_ps_heads(self, configs):
        """Create the multi-scale pools, fusion MLP and P/S regression/classification heads."""
        d_model, dropout = configs.d_model, configs.dropout
        # No attention layer (memory); multi-scale pooling only.
        self.multi_scale_pools = nn.ModuleList(
            [nn.AdaptiveAvgPool1d(size) for size in self._PS_POOL_SIZES])
        fusion_dim = d_model * sum(self._PS_POOL_SIZES)
        self.feature_fusion = nn.Sequential(
            nn.Linear(fusion_dim, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.p_regression_head = self._make_regression_head(d_model, dropout)  # P time
        self.s_regression_head = self._make_regression_head(d_model, dropout)  # S time
        self.p_classification_head = self._make_classification_head(d_model, dropout)  # P exists/not
        self.s_classification_head = self._make_classification_head(d_model, dropout)  # S exists/not

    @staticmethod
    def _make_regression_head(d_model, dropout):
        """MLP d_model -> 128 -> 64 -> 1 producing one unbounded value per sample."""
        return nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    @staticmethod
    def _make_classification_head(d_model, dropout):
        """MLP d_model -> 64 -> 32 -> 1 followed by a sigmoid (output in (0, 1))."""
        return nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _normalize(x_enc):
        """Instance-normalize ``x_enc`` over the time axis (dim 1).

        Normalization from Non-stationary Transformer. Returns the normalized
        tensor plus the (detached) means and the stdev, each of shape
        (B, 1, C), for later de-normalization.
        """
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place: an in-place division would modify the tensor that
        # torch.var saved for backward when x_enc requires grad.
        x_enc = x_enc / stdev
        return x_enc, means, stdev

    @staticmethod
    def _denormalize(dec_out, means, stdev):
        """Undo ``_normalize``; (B, 1, C) stats broadcast over any output length."""
        return dec_out * stdev + means

    def _run_timesnet(self, enc_out):
        """Apply the first ``self.layer`` TimesBlocks, each followed by the shared LayerNorm."""
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))
        return enc_out

    def _multi_scale_features(self, enc_out):
        """Pool (B, T, d_model) features to sizes _PS_POOL_SIZES and concatenate.

        Returns a tensor of shape (B, d_model * sum(_PS_POOL_SIZES)).
        """
        enc_out_transposed = enc_out.permute(0, 2, 1)  # (B, d_model, T)
        T = enc_out_transposed.size(2)
        features = []
        for target_size, pool in zip(self._PS_POOL_SIZES, self.multi_scale_pools):
            if T >= self._PS_MANUAL_POOL_MIN_LEN:
                # Manual average pooling over equal windows; the tail that does
                # not fill a whole window is dropped. window >= 1 because
                # T >= _PS_MANUAL_POOL_MIN_LEN far exceeds every pool size.
                window = T // target_size
                trimmed = enc_out_transposed[:, :, :target_size * window]
                pooled = trimmed.reshape(
                    trimmed.size(0), trimmed.size(1), target_size, window).mean(dim=3)
            else:
                pooled = pool(enc_out_transposed)  # (B, d_model, target_size)
            features.append(pooled.flatten(1))  # (B, d_model * target_size)
        return torch.cat(features, dim=1)

    def _predict_ps(self, enc_out):
        """Run the P/S heads on encoder output; returns (ps_times, ps_classification), each (B, 2)."""
        fused_features = self._multi_scale_features(enc_out)
        final_features = self.feature_fusion(fused_features)  # (B, d_model)
        # Heads are evaluated in a fixed order (keeps dropout RNG order stable).
        p_time = self.p_regression_head(final_features)  # (B, 1)
        s_time = self.s_regression_head(final_features)  # (B, 1)
        p_class = self.p_classification_head(final_features)  # (B, 1)
        s_class = self.s_classification_head(final_features)  # (B, 1)
        ps_times = torch.cat([p_time, s_time], dim=1)
        ps_classification = torch.cat([p_class, s_class], dim=1)
        return ps_times, ps_classification

    # ------------------------------------------------------------------
    # Task paths
    # ------------------------------------------------------------------
    def anomaly_detection(self, x_enc):
        """Reconstruct ``x_enc`` or, if P/S heads exist, predict P/S outputs.

        Returns:
            Without P/S heads: reconstruction of shape (B, T, c_out),
            de-normalized with the input's statistics.
            With P/S heads: tuple ``(ps_times, ps_classification)``, each of
            shape (B, 2) (column 0 = P, column 1 = S); the classification
            values come from a sigmoid.
        """
        x_enc, means, stdev = self._normalize(x_enc)
        enc_out = self.enc_embedding(x_enc, None)  # [B, T, d_model]
        enc_out = self._run_timesnet(enc_out)
        # Transfer-learning mode: only the P/S heads are used.
        if hasattr(self, 'p_regression_head'):
            return self._predict_ps(enc_out)
        dec_out = self.projection(enc_out)
        return self._denormalize(dec_out, means, stdev)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the path selected by ``task_name``; returns None for unknown tasks."""
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            # [B, L, D], or (ps_times [B, 2], ps_classification [B, 2]) with P/S heads
            return self.anomaly_detection(x_enc)
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, num_class]
        return None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast: extend the embedded sequence to seq_len + pred_len steps and decode."""
        x_enc, means, stdev = self._normalize(x_enc)
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B, T, d_model]
        # align temporal dimension: seq_len -> seq_len + pred_len
        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)
        enc_out = self._run_timesnet(enc_out)
        dec_out = self.projection(enc_out)
        return self._denormalize(dec_out, means, stdev)

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Impute masked values; ``mask`` is 1 for observed and 0 for missing entries.

        Statistics are computed over observed entries only. NOTE(review): the
        mean sums all of ``x_enc``, so it assumes the caller has already
        zeroed the missing entries — confirm against the data pipeline.
        """
        # Clamp so a fully-masked channel yields finite stats instead of NaN.
        observed = torch.sum(mask == 1, dim=1).clamp(min=1)
        means = (torch.sum(x_enc, dim=1) / observed).unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / observed + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc = x_enc / stdev
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B, T, d_model]
        enc_out = self._run_timesnet(enc_out)
        dec_out = self.projection(enc_out)
        return self._denormalize(dec_out, means, stdev)

    def classification(self, x_enc, x_mark_enc):
        """Classify each sequence; returns logits of shape (B, num_class).

        ``x_mark_enc`` multiplies the features per time step (it zeroes
        padding); presumably a (B, T) 0/1 padding mask — confirm with caller.
        """
        enc_out = self.enc_embedding(x_enc, None)  # [B, T, d_model]
        enc_out = self._run_timesnet(enc_out)
        # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.act(enc_out)
        output = self.dropout(output)
        # zero-out padding embeddings
        output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        return self.projection(output)  # (batch_size, num_classes)