Skip to content

14 annual embed#39

Open
meiertgrootes wants to merge 6 commits intomainfrom
14_annual_embed
Open

14 annual embed#39
meiertgrootes wants to merge 6 commits intomainfrom
14_annual_embed

Conversation

@meiertgrootes
Copy link
Copy Markdown
Collaborator

This pull request replaces positional encoding in time with a day-of-year (and hour-of-day) based temporal embedding.
This is expected to improve training results, in particular when using multiple years of data.

Copy link
Copy Markdown
Member

@SarahAlidoost SarahAlidoost left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@meiertgrootes thanks! this is very useful as now the model can capture seasonality/diurnal cycles. I left some comments. The major one is about how the Fourier time encoding is currently implemented, see my comments and let me know if something isn't clear.

Comment on lines +144 to +145
sinx = torch.sin(x)
cosx = torch.cos(x)
Copy link
Copy Markdown
Member

@SarahAlidoost SarahAlidoost Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The day_timef is already cyclic [doy_sin, doy_cos, hod_sin, hod_cos], but here it applies sin/cos again i.e sin(k * already_sin_cos_value) and cos(k * already_sin_cos_value), this is not correct. For standard Fourier time features, raw phase should be used, so in the function add_month_day_dims, the related code can be fixed from:

    doy_sin = np.sin(2*np.pi*doy/doy_period)
    doy_cos = np.cos(2*np.pi*doy/doy_period)
    hod_sin = np.sin(2*np.pi*hod/hod_period)
    hod_cos = np.cos(2*np.pi*hod/hod_period)

to:

    doy_phase = (2*np.pi*doy/doy_period)
    hod_phase = (2*np.pi*hod/hod_period)

After this change, the dimension should be changed from 4 to 2 as well.

Copy link
Copy Markdown
Member

@SarahAlidoost SarahAlidoost Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, if you want to keep the add_month_day_dims to return the sin/cos instead of raw phase, then class CyclicTimeEmbedding should change to use a nn.Sequential to project encoded features to embed_dim and remove sin/cos inside the forward.


# Positional encodings for days and months
self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days)
#self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) REMOVE THIS AND REPLACE WITH time_embed
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) REMOVE THIS AND REPLACE WITH time_embed

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given time_embed already implicitly captures day-within-month, then we can remove self.pos_days

T: number of temporal tokens per month after temporal patching (Tp)
H: spatial height after spatial patching
W: spatial width after spatial patching
time_features: (B,M,T,4) containing cyclically encoded DOY and HOD
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the dimension 4 to 2 after fixing the time_features

daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
data, where C is the number of channels (e.g., 1 for SST)
daily_mask: Boolean tensor of same shape as daily_data indicating missing values
daily_timef: Tensor of shape (B, M, T, 4) containing the cyclically encoded day-of-year
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the dimension 4 to 2 after fixing the time_features

Comment on lines +108 to +109
self.embed_dim = embed_dim
self.base_dim = base_dim
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.embed_dim = embed_dim
self.base_dim = base_dim

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused variables.

Comment thread climanet/utils.py
#-------------------------------------------

#determine day-of-year (doy) [and hour-of-day (hod) if applicable], fill NaT with 0 inplace
doy_period = 365.0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leap years will be slightly mis-phased here. We can add a comment about it here, or we can calculate it as:
days_in_year = xr.where(time_indexed.dt.is_leap_year, 366.0, 365.0).fillna(365.0)

Comment thread climanet/utils.py
.unstack(time_dim)
.reindex(T=np.arange(1,32), M=month_keys)
)
#-------------------------------------------
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#-------------------------------------------

Comment thread climanet/utils.py
)

return daily_indexed, monthly_m, padded_days_mask
#-----------------------------------------
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#-----------------------------------------

Comment thread climanet/dataset.py
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 4)
daily_timef_tensor = torch.from_numpy(daily_timef_patch).float()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
daily_timef_tensor = torch.from_numpy(daily_timef_patch).float()
daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float()

Comment thread climanet/dataset.py
daily_nan_mask = self.daily_nan_mask[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W)
daily_timef_patch = self.daily_timef_np # (M,T,4)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
daily_timef_patch = self.daily_timef_np # (M,T,4)

@rogerkuou
Copy link
Copy Markdown
Collaborator

@meiertgrootes as I indicated in the Teams message I wont have time in the comming days to review this. I am fine with merging on Sarah's approval!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants