diff --git a/user-src/ipc-linux.h b/user-src/ipc-linux.h index 42fc3cd..eda7e68 100644 --- a/user-src/ipc-linux.h +++ b/user-src/ipc-linux.h @@ -540,20 +540,24 @@ static int parse_privkey(const struct nlattr *attr, void *data) switch (mnl_attr_get_type(attr)) { case WGDEVICE_A_PRIVATE_KEY: - *privkey->privkey_len = mnl_attr_get_payload_len(attr); - *privkey->privkey = malloc(*privkey->privkey_len); + if (mnl_attr_get_payload_len(attr) != WG_PRIVATE_KEY_LEN) + return MNL_CB_ERROR; + *privkey->privkey_len = WG_PRIVATE_KEY_LEN; + *privkey->privkey = malloc(WG_PRIVATE_KEY_LEN); if (! *privkey->privkey) return MNL_CB_ERROR; - memcpy(*privkey->privkey, mnl_attr_get_payload(attr), *privkey->privkey_len); + memcpy(*privkey->privkey, mnl_attr_get_payload(attr), WG_PRIVATE_KEY_LEN); break; case WGDEVICE_A_PUBLIC_KEY: if (! privkey->pubkey) return MNL_CB_OK; - *privkey->pubkey_len = mnl_attr_get_payload_len(attr); - *privkey->pubkey = malloc(*privkey->pubkey_len); + if (mnl_attr_get_payload_len(attr) != WG_PUBLIC_KEY_LEN) + return MNL_CB_ERROR; + *privkey->pubkey_len = WG_PUBLIC_KEY_LEN; + *privkey->pubkey = malloc(WG_PUBLIC_KEY_LEN); if (! *privkey->pubkey) return MNL_CB_ERROR; - memcpy(*privkey->pubkey, mnl_attr_get_payload(attr), *privkey->pubkey_len); + memcpy(*privkey->pubkey, mnl_attr_get_payload(attr), WG_PUBLIC_KEY_LEN); break; } @@ -605,8 +609,10 @@ static int kernel_generate_privkey(uint8_t **privkey, size_t *privkey_len, uint8 } out: - if (nlg) + if (nlg) { mnlg_socket_close(nlg); + nlg = NULL; + } if (ret) { if (*privkey) { memzero_explicit(*privkey, *privkey_len); @@ -641,11 +647,13 @@ static int parse_pubkey(const struct nlattr *attr, void *data) case WGDEVICE_A_PUBLIC_KEY: if (! pubkey->pubkey) return MNL_CB_OK; - *pubkey->pubkey_len = mnl_attr_get_payload_len(attr); - *pubkey->pubkey = malloc(*pubkey->pubkey_len); + if (mnl_attr_get_payload_len(attr) != WG_PUBLIC_KEY_LEN) + return MNL_CB_ERROR; + *pubkey->pubkey_len = WG_PUBLIC_KEY_LEN; + *pubkey->pubkey = malloc(WG_PUBLIC_KEY_LEN); if (! *pubkey->pubkey) return MNL_CB_ERROR; - memcpy(*pubkey->pubkey, mnl_attr_get_payload(attr), *pubkey->pubkey_len); + memcpy(*pubkey->pubkey, mnl_attr_get_payload(attr), WG_PUBLIC_KEY_LEN); break; } @@ -691,8 +699,10 @@ static int kernel_derive_pubkey(const uint8_t *privkey, size_t privkey_len, uint } out: - if (nlg) + if (nlg) { mnlg_socket_close(nlg); + nlg = NULL; + } if (ret) { if (*pubkey) { free(*pubkey); @@ -717,11 +727,13 @@ static int parse_psk(const struct nlattr *attr, void *data) switch (mnl_attr_get_type(attr)) { case WGDEVICE_A_PRESHARED_KEY: - *psk->psk_len = mnl_attr_get_payload_len(attr); - *psk->psk = malloc(*psk->psk_len); + if (mnl_attr_get_payload_len(attr) != WG_SYMMETRIC_KEY_LEN) + return MNL_CB_ERROR; + *psk->psk_len = WG_SYMMETRIC_KEY_LEN; + *psk->psk = malloc(WG_SYMMETRIC_KEY_LEN); if (! *psk->psk) return MNL_CB_ERROR; - memcpy(*psk->psk, mnl_attr_get_payload(attr), *psk->psk_len); + memcpy(*psk->psk, mnl_attr_get_payload(attr), WG_SYMMETRIC_KEY_LEN); break; } @@ -766,8 +778,10 @@ static int kernel_generate_psk(uint8_t **psk, size_t *psk_len) } out: - if (nlg) + if (nlg) { mnlg_socket_close(nlg); + nlg = NULL; + } if (ret) { if (*psk) { memzero_explicit(*psk, *psk_len);