From 23b9437c08a9c41450ca9a4e45d4002498500274 Mon Sep 17 00:00:00 2001 From: Ralf Haferkamp Date: Thu, 4 Aug 2022 10:40:48 +0200 Subject: [PATCH] Avoid panics when LDAP users miss required attributes --- services/graph/pkg/identity/ldap.go | 46 +++++++++++++++++------- services/graph/pkg/identity/ldap_test.go | 43 +++++++++++++++++++++- 2 files changed, 75 insertions(+), 14 deletions(-) diff --git a/services/graph/pkg/identity/ldap.go b/services/graph/pkg/identity/ldap.go index eaab3946c..24f8678c6 100644 --- a/services/graph/pkg/identity/ldap.go +++ b/services/graph/pkg/identity/ldap.go @@ -392,9 +392,12 @@ func (i *LDAP) GetUser(ctx context.Context, nameOrID string, queryParam url.Valu if err != nil { return nil, err } + u := i.createUserModelFromLDAP(e) + if u == nil { + return nil, errNotFound + } sel := strings.Split(queryParam.Get("$select"), ",") exp := strings.Split(queryParam.Get("$expand"), ",") - u := i.createUserModelFromLDAP(e) if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") { userGroups, err := i.getGroupsForUser(e.DN) if err != nil { @@ -455,6 +458,10 @@ func (i *LDAP) GetUsers(ctx context.Context, queryParam url.Values) ([]*libregra sel := strings.Split(queryParam.Get("$select"), ",") exp := strings.Split(queryParam.Get("$expand"), ",") u := i.createUserModelFromLDAP(e) + // Skip invalid LDAP users + if u == nil { + continue + } if slices.Contains(sel, "memberOf") || slices.Contains(exp, "memberOf") { userGroups, err := i.getGroupsForUser(e.DN) if err != nil { @@ -505,8 +512,10 @@ func (i *LDAP) GetGroup(ctx context.Context, nameOrID string, queryParam url.Val } if len(members) > 0 { m := make([]libregraph.User, 0, len(members)) - for _, u := range members { - m = append(m, *i.createUserModelFromLDAP(u)) + for _, ue := range members { + if u := i.createUserModelFromLDAP(ue); u != nil { + m = append(m, *u) + } } g.Members = m } @@ -688,8 +697,10 @@ func (i *LDAP) GetGroups(ctx context.Context, queryParam url.Values) ([]*libregr } if len(members) > 0 { m := make([]libregraph.User, 0, len(members)) - for _, u := range members { - m = append(m, *i.createUserModelFromLDAP(u)) + for _, ue := range members { + if u := i.createUserModelFromLDAP(ue); u != nil { + m = append(m, *u) + } } g.Members = m } @@ -706,14 +717,15 @@ func (i *LDAP) GetGroupMembers(ctx context.Context, groupID string) ([]*libregra return nil, err } - result := []*libregraph.User{} - memberEntries, err := i.expandLDAPGroupMembers(ctx, e) + result := make([]*libregraph.User, 0, len(memberEntries)) if err != nil { return nil, err } for _, member := range memberEntries { - result = append(result, i.createUserModelFromLDAP(member)) + if u := i.createUserModelFromLDAP(member); u != nil { + result = append(result, u) + } } return result, nil @@ -909,12 +921,20 @@ func (i *LDAP) createUserModelFromLDAP(e *ldap.Entry) *libregraph.User { if e == nil { return nil } - return &libregraph.User{ - DisplayName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.displayName)), - Mail: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.mail)), - OnPremisesSamAccountName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.userName)), - Id: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.id)), + + opsan := e.GetEqualFoldAttributeValue(i.userAttributeMap.userName) + id := e.GetEqualFoldAttributeValue(i.userAttributeMap.id) + + if id != "" && opsan != "" { + return &libregraph.User{ + DisplayName: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.displayName)), + Mail: pointerOrNil(e.GetEqualFoldAttributeValue(i.userAttributeMap.mail)), + OnPremisesSamAccountName: &opsan, + Id: &id, + } } + i.logger.Warn().Str("dn", e.DN).Msg("Invalid User. Missing username or id attribute") + return nil } func (i *LDAP) createGroupModelFromLDAP(e *ldap.Entry) *libregraph.Group { diff --git a/services/graph/pkg/identity/ldap_test.go b/services/graph/pkg/identity/ldap_test.go index 7f561e786..b442a9065 100644 --- a/services/graph/pkg/identity/ldap_test.go +++ b/services/graph/pkg/identity/ldap_test.go @@ -42,11 +42,20 @@ var userEntry = ldap.NewEntry("uid=user", "mail": {"user@example"}, "entryuuid": {"abcd-defg"}, }) +var invalidUserEntry = ldap.NewEntry("uid=user", + map[string][]string{ + "uid": {"invalid"}, + "displayname": {"DisplayName"}, + "mail": {"user@example"}, + }) var groupEntry = ldap.NewEntry("cn=group", map[string][]string{ "cn": {"group"}, "entryuuid": {"abcd-defg"}, - "member": {"uid=user,ou=people,dc=test"}, + "member": { + "uid=user,ou=people,dc=test", + "uid=invalid,ou=people,dc=test", + }, }) var invalidGroupEntry = ldap.NewEntry("cn=invalid", map[string][]string{ @@ -196,6 +205,21 @@ func TestGetUser(t *testing.T) { } else if *u.Id != userEntry.GetEqualFoldAttributeValue(b.userAttributeMap.id) { t.Errorf("Expected GetUser to return a valid user") } + + // Mock invalid Search Result + lm = &mocks.Client{} + lm.On("Search", mock.Anything). + Return( + &ldap.SearchResult{ + Entries: []*ldap.Entry{invalidUserEntry}, + }, + nil) + + b, _ = getMockedBackend(lm, lconfig, &logger) + u, err = b.GetUser(context.Background(), "invalid", nil) + if err == nil || err.Error() != "itemNotFound" { + t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) + } } func TestGetUsers(t *testing.T) { @@ -298,9 +322,17 @@ func TestGetGroup(t *testing.T) { Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, Controls: []ldap.Control(nil), } + sr3 := &ldap.SearchRequest{ + BaseDN: "uid=invalid,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil) lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil) + lm.On("Search", sr3).Return(&ldap.SearchResult{Entries: []*ldap.Entry{invalidUserEntry}}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err = b.GetGroup(context.Background(), "group", nil) if err != nil { @@ -385,9 +417,18 @@ func TestGetGroups(t *testing.T) { Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, Controls: []ldap.Control(nil), } + sr3 := &ldap.SearchRequest{ + BaseDN: "uid=invalid,ou=people,dc=test", + SizeLimit: 1, + Filter: "(objectclass=*)", + Attributes: []string{"displayname", "entryUUID", "mail", "uid"}, + Controls: []ldap.Control(nil), + } + for _, param := range []url.Values{queryParamSelect, queryParamExpand} { lm.On("Search", sr1).Return(&ldap.SearchResult{Entries: []*ldap.Entry{groupEntry}}, nil) lm.On("Search", sr2).Return(&ldap.SearchResult{Entries: []*ldap.Entry{userEntry}}, nil) + lm.On("Search", sr3).Return(&ldap.SearchResult{Entries: []*ldap.Entry{invalidUserEntry}}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err = b.GetGroups(context.Background(), param) switch {