diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..26877d8c Binary files /dev/null and b/.DS_Store differ diff --git a/cmd/protoc-gen-doc/main.go b/cmd/protoc-gen-doc/main.go index 56e42ce4..eecc5d62 100644 --- a/cmd/protoc-gen-doc/main.go +++ b/cmd/protoc-gen-doc/main.go @@ -20,6 +20,7 @@ import ( "os" gendoc "github.com/pseudomuto/protoc-gen-doc" + _ "github.com/pseudomuto/protoc-gen-doc/extensions/generic" // imports the generic extension handler _ "github.com/pseudomuto/protoc-gen-doc/extensions/google_api_http" // imported for side effects _ "github.com/pseudomuto/protoc-gen-doc/extensions/lyft_validate" // imported for side effects _ "github.com/pseudomuto/protoc-gen-doc/extensions/validator_field" // imported for side effects @@ -50,4 +51,4 @@ func HandleFlags(f *Flags) bool { } return true -} +} \ No newline at end of file diff --git a/examples/.DS_Store b/examples/.DS_Store new file mode 100644 index 00000000..4327e78c Binary files /dev/null and b/examples/.DS_Store differ diff --git a/extensions/.DS_Store b/extensions/.DS_Store new file mode 100644 index 00000000..71b6c17f Binary files /dev/null and b/extensions/.DS_Store differ diff --git a/extensions/extensions.go b/extensions/extensions.go index 5e48c72a..345f9718 100644 --- a/extensions/extensions.go +++ b/extensions/extensions.go @@ -1,35 +1,70 @@ // Package extensions implements a system for working with extended options. package extensions +import ( + "strings" + "github.com/golang/protobuf/proto" + "fmt" + "github.com/golang/protobuf/protoc-gen-go/descriptor" +) + // Transformer functions for transforming payloads of an extension option into // something that can be rendered by a template. type Transformer func(payload interface{}) interface{} var transformers = make(map[string]Transformer) +var defaultTransformer Transformer +var transformNameAliases = make(map[string]string) // SetTransformer sets the transformer function for the given extension name func SetTransformer(extensionName string, f Transformer) { transformers[extensionName] = f } -// Transform the extensions using the registered transformers. +// SetNameAlias sets an alias for an extension name, allowing different keys +// to map to the same transformer +func SetNameAlias(fullName, aliasTo string) { + transformNameAliases[fullName] = aliasTo +} + +// SetDefaultTransformer sets the default transformer function for any +// extension that doesn't have a specific transformer registered +func SetDefaultTransformer(f Transformer) { + defaultTransformer = f +} + +// Generic Transform function func Transform(extensions map[string]interface{}) map[string]interface{} { if extensions == nil { return nil } + out := make(map[string]interface{}, len(extensions)) - for name, payload := range extensions { - transform, ok := transformers[name] - if !ok { - // No transformer registered, skip. - continue + + for originalKey, payload := range extensions { + transformedName := originalKey // fallback + + // Resolve registered extensions explicitly + if extDescs, err := proto.ExtensionDescs((*descriptor.MethodOptions)(nil)); err == nil { + for _, extDesc := range extDescs { + protoKey := fmt.Sprintf(".google.protobuf.MethodOptions.%s", extDesc.Name[strings.LastIndex(extDesc.Name, ".")+1:]) + if protoKey == originalKey { + transformedName = extDesc.Name + break + } + } } - transformedPayload := transform(payload) - if transformedPayload == nil { - // Transformer returned nothing, skip. - continue + + // Apply the transformer if one exists + transform, ok := transformers[transformedName] + if ok { + out[transformedName] = transform(payload) + } else if defaultTransformer != nil { + out[transformedName] = defaultTransformer(payload) + } else { + out[transformedName] = payload } - out[name] = transformedPayload } + return out } diff --git a/extensions/extensions_register.go b/extensions/extensions_register.go new file mode 100644 index 00000000..5c777276 --- /dev/null +++ b/extensions/extensions_register.go @@ -0,0 +1,152 @@ +package extensions + +import ( + "fmt" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/pseudomuto/protokit" +) + +func RegisterAllExtensions(files []*protokit.FileDescriptor) { + for _, file := range files { + for _, ext := range file.GetExtensions() { + registerExtension(ext) + } + + for _, msg := range file.GetMessages() { + registerMessageExtensions(msg) + } + } +} + +func registerMessageExtensions(msg *protokit.Descriptor) { + for _, ext := range msg.GetExtensions() { + registerExtension(ext) + } + for _, nestedMsg := range msg.GetMessages() { + registerMessageExtensions(nestedMsg) + } +} +func registerExtension(ext *protokit.ExtensionDescriptor) { + extendedType := determineExtendedType(ext.GetExtendee()) + extType := determineExtensionType(ext) + if extType == nil { + return + } + + correctFullName := fmt.Sprintf("%s.%s", ext.GetFile().GetPackage(), ext.GetName()) + + proto.RegisterExtension(&proto.ExtensionDesc{ + ExtendedType: extendedType, + ExtensionType: extType, + Field: int32(ext.GetNumber()), + Name: correctFullName, + Tag: generateTag(ext), + }) +} + + + + + + +// Helper functions: + +func determineExtendedType(typeName string) proto.Message { + switch typeName { + case ".google.protobuf.MethodOptions": + return (*descriptor.MethodOptions)(nil) + case ".google.protobuf.FieldOptions": + return (*descriptor.FieldOptions)(nil) + case ".google.protobuf.MessageOptions": + return (*descriptor.MessageOptions)(nil) + case ".google.protobuf.FileOptions": + return (*descriptor.FileOptions)(nil) + case ".google.protobuf.EnumOptions": + return (*descriptor.EnumOptions)(nil) + case ".google.protobuf.EnumValueOptions": + return (*descriptor.EnumValueOptions)(nil) + case ".google.protobuf.ServiceOptions": + return (*descriptor.ServiceOptions)(nil) + default: + panic(fmt.Sprintf("Unsupported extendee type: %s", typeName)) + } +} + + +func determineExtensionType(ext *protokit.ExtensionDescriptor) interface{} { + isRepeated := ext.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED + + switch ext.GetType() { + case descriptor.FieldDescriptorProto_TYPE_STRING: + if isRepeated { + return ([]string)(nil) + } + return (*string)(nil) + case descriptor.FieldDescriptorProto_TYPE_BOOL: + if isRepeated { + return ([]bool)(nil) + } + return (*bool)(nil) + case descriptor.FieldDescriptorProto_TYPE_INT32, descriptor.FieldDescriptorProto_TYPE_SINT32: + if isRepeated { + return ([]int32)(nil) + } + return (*int32)(nil) + case descriptor.FieldDescriptorProto_TYPE_INT64, descriptor.FieldDescriptorProto_TYPE_SINT64: + if isRepeated { + return ([]int64)(nil) + } + return (*int64)(nil) + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + if isRepeated { + return ([]float32)(nil) + } + return (*float32)(nil) + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + if isRepeated { + return ([]float64)(nil) + } + return (*float64)(nil) + case descriptor.FieldDescriptorProto_TYPE_ENUM: + if isRepeated { + return ([]int32)(nil) // enums represented as repeated []int32 + } + return (*int32)(nil) // single enum represented as *int32 + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + return nil + default: + panic(fmt.Sprintf("Unsupported extension field type: %s", ext.GetType().String())) + } +} + + + +func generateTag(ext *protokit.ExtensionDescriptor) string { + var wireType string + switch ext.GetType() { + case descriptor.FieldDescriptorProto_TYPE_STRING: + wireType = "bytes" + case descriptor.FieldDescriptorProto_TYPE_BOOL, + descriptor.FieldDescriptorProto_TYPE_INT32, + descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SINT64, + descriptor.FieldDescriptorProto_TYPE_ENUM: + wireType = "varint" + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + wireType = "fixed32" + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + wireType = "fixed64" + default: + panic(fmt.Sprintf("Unsupported tag type for extension: %s", ext.GetType().String())) + } + + label := "opt" + if ext.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED { + label = "rep" + } + + fieldName := ext.GetName() + return fmt.Sprintf("%s,%d,%s,name=%s", wireType, ext.GetNumber(), label, fieldName) +} diff --git a/extensions/generic/generic.go b/extensions/generic/generic.go new file mode 100644 index 00000000..01316d69 --- /dev/null +++ b/extensions/generic/generic.go @@ -0,0 +1,45 @@ +// Package generic provides a mechanism to register and transform any extension +// without knowing its specifics in advance +package generic + +import ( + "fmt" + "github.com/golang/protobuf/proto" + "github.com/pseudomuto/protoc-gen-doc/extensions" + "regexp" +) + +// ExtensionPattern defines a pattern for registering multiple extensions at once +type ExtensionPattern struct { + Name string + Number int32 + Type proto.Message +} + +// Initialize extension patterns - can be extended for more types +func init() { + // Register a generic transformer for all extensions + extensions.SetDefaultTransformer(identityTransformer) +} + + +// formatFieldNumber formats a field number for use in extension names +func formatFieldNumber(num int32) string { + return fmt.Sprintf("field_%d", num) +} + +// identityTransformer returns the extension value as-is +func identityTransformer(payload interface{}) interface{} { + return payload +} + +// ExtractOptionName extracts a cleaner name from a full extension name +func ExtractOptionName(fullName string) string { + // Match the base name pattern (everything before the last dot and numbers) + re := regexp.MustCompile(`(.*?)\.[\w_]+\d+$`) + matches := re.FindStringSubmatch(fullName) + if len(matches) > 1 { + return matches[1] + } + return fullName +} diff --git a/plugin.go b/plugin.go index 5fbabe6d..7b4281e5 100644 --- a/plugin.go +++ b/plugin.go @@ -7,7 +7,7 @@ import ( "path/filepath" "regexp" "strings" - + "github.com/pseudomuto/protoc-gen-doc/extensions" "github.com/golang/protobuf/proto" plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/pseudomuto/protokit" @@ -40,6 +40,10 @@ func (p *Plugin) Generate(r *plugin_go.CodeGeneratorRequest) (*plugin_go.CodeGen result := excludeUnwantedProtos(protokit.ParseCodeGenRequest(r), options.ExcludePatterns) + // Dynamically register extensions + extensions.RegisterAllExtensions(result) + + customTemplate := "" if options.TemplateFile != "" { diff --git a/template.go b/template.go index ee0d933f..d8a57924 100644 --- a/template.go +++ b/template.go @@ -6,10 +6,12 @@ import ( "sort" "strings" "unicode" - - "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/golang/protobuf/proto" "github.com/pseudomuto/protoc-gen-doc/extensions" "github.com/pseudomuto/protokit" + descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" + + ) // Template is a type for encapsulating all the parsed files, messages, fields, enums, services, extensions, etc. into @@ -541,22 +543,44 @@ func parseService(ps *protokit.ServiceDescriptor) *Service { return service } + + + + func parseServiceMethod(pm *protokit.MethodDescriptor) *ServiceMethod { - return &ServiceMethod{ - Name: pm.GetName(), - Description: description(pm.GetComments().String()), - RequestType: baseName(pm.GetInputType()), - RequestLongType: strings.TrimPrefix(pm.GetInputType(), "."+pm.GetPackage()+"."), - RequestFullType: strings.TrimPrefix(pm.GetInputType(), "."), - RequestStreaming: pm.GetClientStreaming(), - ResponseType: baseName(pm.GetOutputType()), - ResponseLongType: strings.TrimPrefix(pm.GetOutputType(), "."+pm.GetPackage()+"."), - ResponseFullType: strings.TrimPrefix(pm.GetOutputType(), "."), - ResponseStreaming: pm.GetServerStreaming(), - Options: mergeOptions(extractOptions(pm.GetOptions()), extensions.Transform(pm.OptionExtensions)), - } + methodOptions := extractOptions(pm.GetOptions()) + methodExtensionOptions := make(map[string]interface{}) + + methodOpts := pm.GetOptions() + + if methodOpts != nil { + // Using registered extensions to get correct names dynamically + for _, extDesc := range proto.RegisteredExtensions(methodOpts) { + if extValue, err := proto.GetExtension(methodOpts, extDesc); err == nil && extValue != nil { + methodExtensionOptions[extDesc.Name] = extValue + } + } + } + + return &ServiceMethod{ + Name: pm.GetName(), + Description: description(pm.GetComments().String()), + RequestType: baseName(pm.GetInputType()), + RequestLongType: strings.TrimPrefix(pm.GetInputType(), "."+pm.GetPackage()+"."), + RequestFullType: strings.TrimPrefix(pm.GetInputType(), "."), + RequestStreaming: pm.GetClientStreaming(), + ResponseType: baseName(pm.GetOutputType()), + ResponseLongType: strings.TrimPrefix(pm.GetOutputType(), "."+pm.GetPackage()+"."), + ResponseFullType: strings.TrimPrefix(pm.GetOutputType(), "."), + ResponseStreaming: pm.GetServerStreaming(), + Options: mergeOptions( + methodOptions, + methodExtensionOptions, + ), + } } + func baseName(name string) string { parts := strings.Split(name, ".") return parts[len(parts)-1] diff --git a/thirdparty/.DS_Store b/thirdparty/.DS_Store new file mode 100644 index 00000000..24b89a7e Binary files /dev/null and b/thirdparty/.DS_Store differ diff --git a/thirdparty/github.com/.DS_Store b/thirdparty/github.com/.DS_Store new file mode 100644 index 00000000..0cf4dc84 Binary files /dev/null and b/thirdparty/github.com/.DS_Store differ