diff --git a/reflections.go b/reflections.go index 90c0ac2..412566a 100644 --- a/reflections.go +++ b/reflections.go @@ -104,6 +104,15 @@ func GetFieldTag(obj interface{}, fieldName, tagKey string) (string, error) { // The `obj` parameter must be a `struct`, or a `pointer` to one. If the `obj` parameter doesn't have a field tagged // with the `tagKey`, and the matching `tagValue`, this function returns an error. func GetFieldNameByTagValue(obj interface{}, tagKey, tagValue string) (string, error) { + return getFieldNameByTagValue(obj, tagKey, tagValue, false) +} + +// GetFieldNameByTagValueDeep considers embedded fields too while scanning for a matching a matching `{tagKey}:"{tagValue}"` tag +func GetFieldNameByTagValueDeep(obj interface{}, tagKey, tagValue string) (string, error) { + return getFieldNameByTagValue(obj, tagKey, tagValue, true) +} + +func getFieldNameByTagValue(obj interface{}, tagKey, tagValue string, deep bool) (string, error) { if !isSupportedType(obj, []reflect.Kind{reflect.Struct, reflect.Ptr}) { return "", fmt.Errorf("cannot use GetFieldByTag on a non-struct interface: %w", ErrUnsupportedType) } @@ -114,8 +123,21 @@ func GetFieldNameByTagValue(obj interface{}, tagKey, tagValue string) (string, e for i := range fieldsCount { structField := objType.Field(i) - if structField.Tag.Get(tagKey) == tagValue { - return structField.Name, nil + if isExportableField(structField) { + if !deep || !structField.Anonymous { + if structField.Tag.Get(tagKey) == tagValue { + return structField.Name, nil + } + continue + } + + fieldValue := objValue.Field(i) + m, err := getFieldNameByTagValue(fieldValue.Interface(), tagKey, tagValue, deep) + if err != nil { + return "", fmt.Errorf("cannot get items in %s: %w", structField.Name, err) + } else if m != "" { + return m, nil + } } } diff --git a/reflections_test.go b/reflections_test.go index 6cf8975..3591f9b 100644 --- a/reflections_test.go +++ b/reflections_test.go @@ -553,6 +553,45 @@ func TestGetFieldNameByTagValue_on_non_existing_tag(t *testing.T) { require.Error(t, errTagKeyValue) } +func TestGetFieldNameByTagValueDeep(t *testing.T) { + t.Parallel() + + type child struct { + TestStruct + } + dummyStruct := child{} + + tagJSON := "dummytag" + field, err := GetFieldNameByTagValueDeep(dummyStruct, "test", tagJSON) + + require.NoError(t, err) + assert.Equal(t, "Dummy", field) +} + +func TestGetFieldNameByTagValueDeep_on_non_existing_tag(t *testing.T) { + t.Parallel() + + type child struct { + TestStruct + } + dummyStruct := child{} + + // non existing tag value with an existing tag key + tagJSON := "tag" + _, errTagValue := GetFieldNameByTagValueDeep(dummyStruct, "test", tagJSON) + require.Error(t, errTagValue) + + // non existing tag key with an existing tag value + tagJSON = "dummytag" + _, errTagKey := GetFieldNameByTagValueDeep(dummyStruct, "json", tagJSON) + require.Error(t, errTagKey) + + // non existing tag key and value + tagJSON = "tag" + _, errTagKeyValue := GetFieldNameByTagValueDeep(dummyStruct, "json", tagJSON) + require.Error(t, errTagKeyValue) +} + //nolint:unused func TestTags_deep(t *testing.T) { t.Parallel()