diff --git a/control.go b/control.go index a42b5bb..463f979 100644 --- a/control.go +++ b/control.go @@ -100,8 +100,24 @@ func FindControl(controls []Control, controlType string) Control { return nil } -func DecodeControl(packet *ber.Packet) Control { - ControlType := packet.Children[0].Value.(string) +func DecodeControl(packet *ber.Packet) (control Control, err error) { + defer func() { + if r := recover(); r != nil { + control = nil + err = fmt.Errorf("ldap: failed to decode control: %v", r) + } + }() + + if packet == nil || len(packet.Children) == 0 { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control packet") + } + if packet.Children[0] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control packet") + } + ControlType, ok := packet.Children[0].Value.(string) + if !ok { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control type") + } packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" c := new(ControlString) c.ControlType = ControlType @@ -111,8 +127,18 @@ func DecodeControl(packet *ber.Packet) Control { value := packet.Children[1] if len(packet.Children) == 3 { value = packet.Children[2] + if packet.Children[1] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control criticality") + } packet.Children[1].Description = "Criticality" - c.Criticality = packet.Children[1].Value.(bool) + criticality, ok := packet.Children[1].Value.(bool) + if !ok { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control criticality") + } + c.Criticality = criticality + } + if value == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control value") } value.Description = "Control Value" @@ -126,18 +152,33 @@ func DecodeControl(packet *ber.Packet) Control { value.Value = nil value.AppendChild(valueChildren) } + // Exactly one controlValue, RFC2696 + if len(value.Children) != 1 || value.Children[0] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed paging control value") + } value = value.Children[0] value.Description = "Search Control Value" + if len(value.Children) < 2 || value.Children[0] == nil || value.Children[1] == nil { + return nil, fmt.Errorf("ldap: failed to decode control: malformed paging control value") + } value.Children[0].Description = "Paging Size" value.Children[1].Description = "Cookie" - c.PagingSize = uint32(value.Children[0].Value.(int64)) + pagingSize, ok := value.Children[0].Value.(int64) + if !ok || pagingSize < 0 || pagingSize > int64(^uint32(0)) { + return nil, fmt.Errorf("ldap: failed to decode control: malformed paging control size") + } + c.PagingSize = uint32(pagingSize) c.Cookie = value.Children[1].Data.Bytes() value.Children[1].Value = c.Cookie - return c + return c, nil + } + controlValue, ok := value.Value.(string) + if !ok { + return nil, fmt.Errorf("ldap: failed to decode control: malformed control value") } - c.ControlValue = value.Value.(string) + c.ControlValue = controlValue } - return c + return c, nil } func NewControlString(controlType string, criticality bool, controlValue string) *ControlString { diff --git a/search.go b/search.go index 7b54805..3ca8668 100644 --- a/search.go +++ b/search.go @@ -331,7 +331,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { - result.Controls = append(result.Controls, DecodeControl(child)) + control, err := DecodeControl(child) + if err != nil { + return result, NewError(ErrorNetwork, err) + } + result.Controls = append(result.Controls, control) } } foundSearchResultDone = true diff --git a/server.go b/server.go index 161e9dd..5dbc91d 100644 --- a/server.go +++ b/server.go @@ -256,7 +256,18 @@ handler: controls := []Control{} if len(packet.Children) > 2 { for _, child := range packet.Children[2].Children { - controls = append(controls, DecodeControl(child)) + control, err := DecodeControl(child) + if err != nil { + log.Printf("DecodeControl error %s", err.Error()) + responsePacket := encodeProtocolErrorResponse(messageID, req.Tag) + if responsePacket != nil { + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + } + } + break handler + } + controls = append(controls, control) } } @@ -324,7 +335,7 @@ handler: case ApplicationExtendedRequest: var tlsConn *tls.Conn if n := len(req.Children); n == 1 || n == 2 { - if name := ber.DecodeString(req.Children[0].Data.Bytes()); name == oidStartTLS && server.TLSConfig != nil { + if name := ber.DecodeString(req.Children[0].Data.Bytes()); name == oidStartTLS { tlsConn = tls.Server(conn, server.TLSConfig) } } @@ -401,6 +412,29 @@ func sendPacket(conn net.Conn, packet *ber.Packet) error { return nil } +func encodeProtocolErrorResponse(messageID uint64, requestType ber.Tag) *ber.Packet { + switch requestType { + case ApplicationBindRequest: + return encodeBindResponse(messageID, LDAPResultProtocolError) + case ApplicationSearchRequest: + return encodeSearchDone(messageID, LDAPResultProtocolError) + case ApplicationModifyRequest: + return encodeLDAPResponse(messageID, ApplicationModifyResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationAddRequest: + return encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationDelRequest: + return encodeLDAPResponse(messageID, ApplicationDelResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationModifyDNRequest: + return encodeLDAPResponse(messageID, ApplicationModifyDNResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationCompareRequest: + return encodeLDAPResponse(messageID, ApplicationCompareResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + case ApplicationExtendedRequest: + return encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, LDAPResultCodeMap[LDAPResultProtocolError]) + default: + return nil + } +} + func routeFunc(dn string, funcNames []string) string { bestPick := "" bestPickWeight := 0