Conversation
…ures to add_month_day_dims, to provide input for temporal encodings
SarahAlidoost
left a comment
@meiertgrootes thanks! This is very useful, as the model can now capture seasonal/diurnal cycles. I left some comments. The major one is about how the Fourier time encoding is currently implemented; see my comments and let me know if anything is unclear.
```python
sinx = torch.sin(x)
cosx = torch.cos(x)
```
The `day_timef` is already cyclic `[doy_sin, doy_cos, hod_sin, hod_cos]`, but here sin/cos is applied again, i.e. `sin(k * already_sin_cos_value)` and `cos(k * already_sin_cos_value)`, which is not correct. For standard Fourier time features, the raw phase should be used, so in the function `add_month_day_dims` the related code can be changed from:

```python
doy_sin = np.sin(2*np.pi*doy/doy_period)
doy_cos = np.cos(2*np.pi*doy/doy_period)
hod_sin = np.sin(2*np.pi*hod/hod_period)
hod_cos = np.cos(2*np.pi*hod/hod_period)
```

to:

```python
doy_phase = (2*np.pi*doy/doy_period)
hod_phase = (2*np.pi*hod/hod_period)
```

After this change, the feature dimension should be changed from 4 to 2 as well.
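For context, a minimal sketch of what standard Fourier time features built from raw phases look like. This is illustrative only, assuming NumPy; `fourier_time_features` is a hypothetical helper name, not code from this PR:

```python
import numpy as np

def fourier_time_features(doy, hod, num_harmonics=2,
                          doy_period=365.0, hod_period=24.0):
    """Build Fourier features from raw phases (hypothetical helper).

    For each harmonic k = 1..num_harmonics, the last axis holds
    [sin(k*doy_phase), cos(k*doy_phase), sin(k*hod_phase), cos(k*hod_phase)],
    giving shape (..., 4*num_harmonics).
    """
    # Raw phases in [0, 2*pi): this is what the embedding should consume.
    doy_phase = 2 * np.pi * np.asarray(doy) / doy_period
    hod_phase = 2 * np.pi * np.asarray(hod) / hod_period
    feats = []
    for k in range(1, num_harmonics + 1):
        # sin/cos is applied exactly once, to k times the raw phase.
        feats += [np.sin(k * doy_phase), np.cos(k * doy_phase),
                  np.sin(k * hod_phase), np.cos(k * hod_phase)]
    return np.stack(feats, axis=-1)
```

The key point is that the harmonics multiply the phase, not an already sin/cos-encoded value.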
Alternatively, if you want to keep `add_month_day_dims` returning the sin/cos values instead of the raw phase, then the class `CyclicTimeEmbedding` should be changed to use an `nn.Sequential` that projects the encoded features to `embed_dim`, and the sin/cos inside `forward` should be removed.
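A sketch of this alternative, assuming PyTorch; the class name matches the comment above, but the layer sizes and MLP structure are illustrative, not the PR's actual implementation:

```python
import torch
import torch.nn as nn

class CyclicTimeEmbedding(nn.Module):
    """Project already-encoded cyclic features, e.g.
    [doy_sin, doy_cos, hod_sin, hod_cos], to embed_dim.
    No sin/cos is applied inside forward."""

    def __init__(self, in_features: int = 4, embed_dim: int = 128):
        super().__init__()
        # nn.Sequential projection, as suggested in the comment above.
        self.proj = nn.Sequential(
            nn.Linear(in_features, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, time_features: torch.Tensor) -> torch.Tensor:
        # time_features: (B, M, T, in_features) -> (B, M, T, embed_dim)
        return self.proj(time_features)
```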
```python
# Positional encodings for days and months
self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days)
#self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) REMOVE THIS AND REPLACE WITH time_embed
```
Given that `time_embed` already implicitly captures the day within the month, we can remove `self.pos_days`.
```python
T: number of temporal tokens per month after temporal patching (Tp)
H: spatial height after spatial patching
W: spatial width after spatial patching
time_features: (B,M,T,4) containing cyclically encoded DOY and HOD
```
Change the dimension from 4 to 2 here after fixing the time features.
```python
daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
    data, where C is the number of channels (e.g., 1 for SST)
daily_mask: Boolean tensor of same shape as daily_data indicating missing values
daily_timef: Tensor of shape (B, M, T, 4) containing the cyclically encoded day-of-year
```
Change the dimension from 4 to 2 here as well after fixing the time features.
```python
self.embed_dim = embed_dim
self.base_dim = base_dim
```
```python
#-------------------------------------------
#determine day-of-year (doy) [and hour-of-day (hod) if applicable], fill NaT with 0 inplace
doy_period = 365.0
```
Leap years will be slightly mis-phased here. We can either add a comment about this, or compute the period per timestamp as:

```python
days_in_year = xr.where(time_indexed.dt.is_leap_year, 366.0, 365.0).fillna(365.0)
```
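A standalone sketch of the leap-year-aware phase, using pandas in place of the xarray code above (variable names illustrative):

```python
import numpy as np
import pandas as pd

# A few days spanning a non-leap-year/leap-year boundary.
times = pd.date_range("2019-12-30", "2020-01-03", freq="D")

# Per-timestamp year length: 366 for leap years, 365 otherwise.
days_in_year = np.where(times.is_leap_year, 366.0, 365.0)

# Day-of-year phase in [0, 2*pi); Jan 1 maps to phase 0 in every year,
# and Dec 31 of a leap year stays strictly below 2*pi.
doy_phase = 2 * np.pi * (times.dayofyear - 1) / days_in_year
```

With a fixed 365.0 period, dates late in a leap year would wrap slightly past the start of the cycle; the per-timestamp period avoids that.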
```python
.unstack(time_dim)
.reindex(T=np.arange(1,32), M=month_keys)
)
#-------------------------------------------
```
Suggested change: remove the `#-------------------------------------------` divider comment.
```python
)
return daily_indexed, monthly_m, padded_days_mask
#-----------------------------------------
```
Suggested change: remove the `#-----------------------------------------` divider comment.
```python
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 4)
daily_timef_tensor = torch.from_numpy(daily_timef_patch).float()
```
Suggested change:

```python
daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float()
```
```python
daily_nan_mask = self.daily_nan_mask[
    :, :, i : i + ph, j : j + pw
]  # (M, T, H, W)
daily_timef_patch = self.daily_timef_np  # (M,T,4)
```
Suggested change: remove the line `daily_timef_patch = self.daily_timef_np`, since `self.daily_timef_np` can be used directly.
|
@meiertgrootes as I indicated in the Teams message, I won't have time in the coming days to review this. I am fine with merging on Sarah's approval!
This pull request replaces positional encoding in time with a day-of-year (and hour-of-day) based temporal embedding.
This is expected to improve training results, in particular when using multiple years of data.
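For intuition on why this helps with multi-year data: a cyclic day-of-year encoding assigns identical features to the same calendar day in different years, whereas an index-based positional encoding assigns them different codes. A minimal NumPy sketch, with illustrative names (not the PR's code):

```python
import numpy as np

def doy_features(day_index):
    """Cyclic encoding of an absolute day index (illustrative;
    uses a fixed 365-day year for simplicity)."""
    doy_phase = 2 * np.pi * (day_index % 365) / 365.0
    return np.stack([np.sin(doy_phase), np.cos(doy_phase)], axis=-1)

# The same calendar day one (non-leap) year apart gets identical
# cyclic features, even though its absolute position differs by 365.
f = doy_features(np.array([10, 10 + 365]))
```

An absolute positional encoding of indices 10 and 375 would differ, so the model would have to learn the yearly repetition instead of getting it for free.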