From 207924f0d5e5e18f2fc52b0a1c3ce522acfa406d Mon Sep 17 00:00:00 2001 From: Chris F Ravenscroft Date: Sat, 16 May 2026 16:21:51 -0700 Subject: [PATCH] Harden LDAP control decoding against malformed BER input Change DecodeControl to return (Control, error) and validate control structure and value types instead of relying on unchecked access. Handle decode failures on both server and client paths, returning protocol errors for bad request controls and surfacing response decode errors. --- control.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++------- search.go | 6 +++++- server.go | 38 +++++++++++++++++++++++++++++++++++-- 3 files changed, 89 insertions(+), 10 deletions(-) diff --git a/control.go b/control.go index a42b5bb..463f979 100644 --- a/control.go +++ b/control.go @@ -100,8 +100,24 @@ func FindControl(controls []Control, controlType string) Control { return nil } -func DecodeControl(packet *ber.Packet) Control { - ControlType := packet.Children[0].Value.(string) +func DecodeControl(packet *ber.Packet) (control Control, err error) { + defer func() { + if r := recover(); r != nil { + control = nil + err = fmt.Errorf("ldap: failed to decode control: %v", r) + } + }() + + if packet == nil || len(packet.Children) == 0 { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control packet") + } + if packet.Children[0] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control packet") + } + ControlType, ok := packet.Children[0].Value.(string) + if !ok { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control type") + } packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" c := new(ControlString) c.ControlType = ControlType @@ -111,8 +127,18 @@ func DecodeControl(packet *ber.Packet) Control { value := packet.Children[1] if len(packet.Children) == 3 { value = packet.Children[2] + if packet.Children[1] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control criticality") + } packet.Children[1].Description = "Criticality" - c.Criticality = packet.Children[1].Value.(bool) + criticality, ok := packet.Children[1].Value.(bool) + if !ok { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control criticality") + } + c.Criticality = criticality + } + if value == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control value") } value.Description = "Control Value" @@ -126,18 +152,33 @@ func DecodeControl(packet *ber.Packet) Control { value.Value = nil value.AppendChild(valueChildren) } + // Exactly one controlValue, RFC2696 + if len(value.Children) != 1 || value.Children[0] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed paging control value") + } value = value.Children[0] value.Description = "Search Control Value" + if len(value.Children) < 2 || value.Children[0] == nil || value.Children[1] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed paging control value") + } value.Children[0].Description = "Paging Size" value.Children[1].Description = "Cookie" - c.PagingSize = uint32(value.Children[0].Value.(int64)) + pagingSize, ok := value.Children[0].Value.(int64) + if !ok || pagingSize < 0 || pagingSize > int64(^uint32(0)) { + return nil, fmt.Errorf("ldap: failed to decode control: malformed paging control size") + } + c.PagingSize = uint32(pagingSize) c.Cookie = value.Children[1].Data.Bytes() value.Children[1].Value = c.Cookie - return c + return c, nil + } + controlValue, ok := value.Value.(string) + if !ok { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control value") } - c.ControlValue = value.Value.(string) + c.ControlValue = controlValue } - return c + return c, nil } func NewControlString(controlType string, criticality bool, controlValue string) *ControlString { diff --git a/search.go b/search.go index 7b54805..3ca8668 100644 --- a/search.go +++ b/search.go @@ -331,7 +331,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { - result.Controls = append(result.Controls, DecodeControl(child)) + control, err := DecodeControl(child) + if err != nil { + return result, NewError(ErrorNetwork, err) + } + result.Controls = append(result.Controls, control) } } foundSearchResultDone = true diff --git a/server.go b/server.go index 161e9dd..5dbc91d 100644 --- a/server.go +++ b/server.go @@ -256,7 +256,18 @@ handler: controls := []Control{} if len(packet.Children) > 2 { for _, child := range packet.Children[2].Children { - controls = append(controls, DecodeControl(child)) + control, err := DecodeControl(child) + if err != nil { + log.Printf("DecodeControl error %s", err.Error()) + responsePacket := encodeProtocolErrorResponse(messageID, req.Tag) + if responsePacket != nil { + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + } + } + break handler + } + controls = append(controls, control) } } @@ -324,7 +335,7 @@ handler: case ApplicationExtendedRequest: var tlsConn *tls.Conn if n := len(req.Children); n == 1 || n == 2 { - if name := ber.DecodeString(req.Children[0].Data.Bytes()); name == oidStartTLS && server.TLSConfig != nil { + if name := ber.DecodeString(req.Children[0].Data.Bytes()); name == oidStartTLS { tlsConn = tls.Server(conn, server.TLSConfig) } } @@ -401,6 +412,29 @@ func sendPacket(conn net.Conn, packet *ber.Packet) error { return nil } +func encodeProtocolErrorResponse(messageID uint64, requestType ber.Tag) *ber.Packet { + switch requestType { + case ApplicationBindRequest: + return encodeBindResponse(messageID, LDAPResultProtocolError) + case ApplicationSearchRequest: + return encodeSearchDone(messageID, LDAPResultProtocolError) + case ApplicationModifyRequest: + return encodeLDAPResponse(messageID, ApplicationModifyResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationAddRequest: + return encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationDelRequest: + return encodeLDAPResponse(messageID, ApplicationDelResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationModifyDNRequest: + return encodeLDAPResponse(messageID, ApplicationModifyDNResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationCompareRequest: + return encodeLDAPResponse(messageID, ApplicationCompareResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationExtendedRequest: + return encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + default: + return nil + } +} + func routeFunc(dn string, funcNames []string) string { bestPick := "" bestPickWeight := 0