Skip to content

Commit 48f9e9d

Browse files
authored
Merge branch 'main' into no-error-on-unexported-methods
2 parents a589578 + 2dc6feb commit 48f9e9d

File tree

10 files changed

+270
-33
lines changed

10 files changed

+270
-33
lines changed

interceptors/trace_stream.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,18 @@ func traceStreamRecv[Message TraceStreamStreamingRecvMessage](
9898
streamingMessage func(any) Message,
9999
) (any, error) {
100100
msg, err := next(ctx, info.RawPayload())
101-
propagator := otel.GetTextMapPropagator()
102-
sm := streamingMessage(msg)
103101
rc, ok := ctx.Value(traceStreamRecvContextKey).(*traceStreamRecvContext)
104102
if !ok {
105103
panic(fmt.Errorf("clue interceptors trace stream receive method called without prior setup (service: %v, method: %v)", info.Service(), info.Method()))
106104
}
105+
if err != nil {
106+
rc.ctx = ctx
107+
return nil, err
108+
}
109+
propagator := otel.GetTextMapPropagator()
110+
sm := streamingMessage(msg)
107111
rc.ctx = propagator.Extract(ctx, propagation.MapCarrier(sm.TraceMetadata()))
108-
return msg, err
112+
return msg, nil
109113
}
110114

111115
// traceStreamWrapRecvAndReturnContext is a helper function for wrapped trace stream receive methods

interceptors/trace_stream_client_test.go

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ func TestTraceBidirectionalStreamClientInterceptor(t *testing.T) {
8282
info.addRawPayload(func() any {
8383
return nil
8484
})
85-
info.addClientStreamingResult(func(res any) *mockTraceStreamingRecvMessage {
86-
return newMockTraceStreamingRecvMessage(assert)
87-
})
8885
info.addService(func() string {
8986
return "TestService"
9087
})
@@ -157,6 +154,37 @@ func TestTraceBidirectionalStreamClientInterceptor(t *testing.T) {
157154
assert.False(payload.hasMore(), "missing expected payload calls")
158155
})
159156

157+
t.Run("receive with error", func(t *testing.T) {
158+
var (
159+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
160+
info = newMockTraceStreamInfo(assert.New(t))
161+
interceptor = &TraceBidirectionalStreamClientInterceptor[*mockTraceStreamInfo, *mockTraceStreamingSendMessage, *mockTraceStreamingRecvMessage]{}
162+
nextCalled = false
163+
next = func(ctx context.Context, _ any) (any, error) {
164+
nextCalled = true
165+
return nil, assert.AnError
166+
}
167+
)
168+
info.addCallType(func() goa.InterceptorCallType {
169+
return goa.InterceptorStreamingRecv
170+
})
171+
info.addRawPayload(func() any {
172+
return nil
173+
})
174+
175+
ctx = SetupTraceStreamRecvContext(ctx)
176+
res, err := interceptor.TraceBidirectionalStream(ctx, info, next)
177+
assert.ErrorIs(t, err, assert.AnError)
178+
assert.Nil(t, res)
179+
180+
assert.NotPanics(t, func() {
181+
ctx = GetTraceStreamRecvContext(ctx)
182+
})
183+
184+
assert.True(t, nextCalled, "missing expected next call")
185+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
186+
})
187+
160188
t.Run("unary", func(t *testing.T) {
161189
var (
162190
assert = assert.New(t)
@@ -205,9 +233,6 @@ func TestTraceServerToClientStreamClientInterceptor(t *testing.T) {
205233
info.addRawPayload(func() any {
206234
return nil
207235
})
208-
info.addClientStreamingResult(func(res any) *mockTraceStreamingRecvMessage {
209-
return newMockTraceStreamingRecvMessage(assert)
210-
})
211236
info.addService(func() string {
212237
return "TestService"
213238
})
@@ -280,6 +305,37 @@ func TestTraceServerToClientStreamClientInterceptor(t *testing.T) {
280305
assert.False(payload.hasMore(), "missing expected payload calls")
281306
})
282307

308+
t.Run("receive with error", func(t *testing.T) {
309+
var (
310+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
311+
info = newMockTraceStreamInfo(assert.New(t))
312+
interceptor = &TraceServerToClientStreamClientInterceptor[*mockTraceStreamInfo, *mockTraceStreamingRecvMessage]{}
313+
nextCalled = false
314+
next = func(ctx context.Context, _ any) (any, error) {
315+
nextCalled = true
316+
return nil, assert.AnError
317+
}
318+
)
319+
info.addCallType(func() goa.InterceptorCallType {
320+
return goa.InterceptorStreamingRecv
321+
})
322+
info.addRawPayload(func() any {
323+
return nil
324+
})
325+
326+
ctx = SetupTraceStreamRecvContext(ctx)
327+
res, err := interceptor.TraceServerToClientStream(ctx, info, next)
328+
assert.ErrorIs(t, err, assert.AnError)
329+
assert.Nil(t, res)
330+
331+
assert.NotPanics(t, func() {
332+
ctx = GetTraceStreamRecvContext(ctx)
333+
})
334+
335+
assert.True(t, nextCalled, "missing expected next call")
336+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
337+
})
338+
283339
t.Run("unary", func(t *testing.T) {
284340
var (
285341
assert = assert.New(t)

interceptors/trace_stream_server_test.go

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ func TestTraceBidirectionalStreamServerInterceptor(t *testing.T) {
8282
info.addRawPayload(func() any {
8383
return nil
8484
})
85-
info.addServerStreamingPayload(func(pay any) *mockTraceStreamingRecvMessage {
86-
return newMockTraceStreamingRecvMessage(assert)
87-
})
8885
info.addService(func() string {
8986
return "TestService"
9087
})
@@ -157,6 +154,37 @@ func TestTraceBidirectionalStreamServerInterceptor(t *testing.T) {
157154
assert.False(payload.hasMore(), "missing expected payload calls")
158155
})
159156

157+
t.Run("receive with error", func(t *testing.T) {
158+
var (
159+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
160+
info = newMockTraceStreamInfo(assert.New(t))
161+
interceptor = &TraceBidirectionalStreamServerInterceptor[*mockTraceStreamInfo, *mockTraceStreamingRecvMessage, *mockTraceStreamingSendMessage]{}
162+
nextCalled = false
163+
next = func(ctx context.Context, payload any) (any, error) {
164+
nextCalled = true
165+
return nil, assert.AnError
166+
}
167+
)
168+
info.addCallType(func() goa.InterceptorCallType {
169+
return goa.InterceptorStreamingRecv
170+
})
171+
info.addRawPayload(func() any {
172+
return nil
173+
})
174+
175+
ctx = SetupTraceStreamRecvContext(ctx)
176+
res, err := interceptor.TraceBidirectionalStream(ctx, info, next)
177+
assert.ErrorIs(t, err, assert.AnError)
178+
assert.Nil(t, res)
179+
180+
assert.NotPanics(t, func() {
181+
ctx = GetTraceStreamRecvContext(ctx)
182+
})
183+
184+
assert.True(t, nextCalled, "missing expected next call")
185+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
186+
})
187+
160188
t.Run("unary", func(t *testing.T) {
161189
var (
162190
assert = assert.New(t)
@@ -285,9 +313,6 @@ func TestTraceClientToServerStreamServerInterceptor(t *testing.T) {
285313
info.addRawPayload(func() any {
286314
return nil
287315
})
288-
info.addServerStreamingPayload(func(pay any) *mockTraceStreamingRecvMessage {
289-
return newMockTraceStreamingRecvMessage(assert)
290-
})
291316
info.addService(func() string {
292317
return "TestService"
293318
})
@@ -360,6 +385,37 @@ func TestTraceClientToServerStreamServerInterceptor(t *testing.T) {
360385
assert.False(payload.hasMore(), "missing expected payload calls")
361386
})
362387

388+
t.Run("receive with error", func(t *testing.T) {
389+
var (
390+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
391+
info = newMockTraceStreamInfo(assert.New(t))
392+
interceptor = &TraceClientToServerStreamServerInterceptor[*mockTraceStreamInfo, *mockTraceStreamingRecvMessage]{}
393+
nextCalled = false
394+
next = func(ctx context.Context, payload any) (any, error) {
395+
nextCalled = true
396+
return nil, assert.AnError
397+
}
398+
)
399+
info.addCallType(func() goa.InterceptorCallType {
400+
return goa.InterceptorStreamingRecv
401+
})
402+
info.addRawPayload(func() any {
403+
return nil
404+
})
405+
406+
ctx = SetupTraceStreamRecvContext(ctx)
407+
res, err := interceptor.TraceClientToServerStream(ctx, info, next)
408+
assert.ErrorIs(t, err, assert.AnError)
409+
assert.Nil(t, res)
410+
411+
assert.NotPanics(t, func() {
412+
ctx = GetTraceStreamRecvContext(ctx)
413+
})
414+
415+
assert.True(t, nextCalled, "missing expected next call")
416+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
417+
})
418+
363419
t.Run("unary", func(t *testing.T) {
364420
var (
365421
assert = assert.New(t)

mock/cmd/cmg/main.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ import (
1212

1313
func main() {
1414
var (
15-
gSet = flag.NewFlagSet("global", flag.ExitOnError)
16-
debug, help, h *bool
17-
addGlobals = func(set *flag.FlagSet) {
15+
gSet = flag.NewFlagSet("global", flag.ExitOnError)
16+
debug, testify, help, h *bool
17+
addGlobals = func(set *flag.FlagSet) {
1818
debug = set.Bool("debug", false, "Print debug output")
19+
testify = set.Bool("testify", false, "Use github.com/stretchr/testify for assertions")
1920
help = set.Bool("help", false, "Print help information")
2021
h = set.Bool("h", false, "Print help information")
2122
}
@@ -67,7 +68,7 @@ func main() {
6768
} else {
6869
ctx = log.Context(ctx, log.WithDisableBuffering(func(ctx context.Context) bool { return true }))
6970
}
70-
err := cluemockgen.Generate(ctx, args, "")
71+
err := cluemockgen.Generate(ctx, args, "", *testify)
7172
if err != nil {
7273
os.Exit(1)
7374
}

mock/cmd/cmg/pkg/generate.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ import (
1313
"goa.design/clue/mock/cmd/cmg/pkg/parse"
1414
)
1515

16-
func Generate(ctx context.Context, patterns []string, dir string) error {
16+
// Generate generates the mocks for the given patterns and directory.
17+
// If testify is true, it uses github.com/stretchr/testify for assertions.
18+
func Generate(ctx context.Context, patterns []string, dir string, testify bool) error {
1719
ps, err := parse.LoadPackages(patterns, dir)
1820
if err != nil {
1921
log.Error(ctx, err)
@@ -23,7 +25,7 @@ func Generate(ctx context.Context, patterns []string, dir string) error {
2325
var errs []error
2426

2527
for _, p := range ps {
26-
err = generatePackage(ctx, p)
28+
err = generatePackage(ctx, p, testify)
2729
if err != nil {
2830
errs = append(errs, err)
2931
}
@@ -36,9 +38,11 @@ func Generate(ctx context.Context, patterns []string, dir string) error {
3638
return nil
3739
}
3840

39-
func generatePackage(ctx context.Context, p parse.Package) error {
41+
// generatePackage generates the mocks for the given package.
42+
// If testify is true, it uses github.com/stretchr/testify for assertions.
43+
func generatePackage(ctx context.Context, p parse.Package, testify bool) error {
4044
ctx = log.With(ctx, log.KV{K: "pkg name", V: p.Name()})
41-
log.Print(ctx, log.KV{K: "pkg path", V: p.PkgPath()})
45+
log.Print(ctx, log.KV{K: "pkg path", V: p.PkgPath()}, log.KV{K: "testify", V: testify})
4246

4347
is, err := p.Interfaces()
4448
if err != nil {
@@ -72,7 +76,7 @@ func generatePackage(ctx context.Context, p parse.Package) error {
7276
}
7377
}
7478
for file, interfaces := range interfacesByFile {
75-
err = generateFile(ctx, p, file, interfaces)
79+
err = generateFile(ctx, p, file, interfaces, testify)
7680
if err != nil {
7781
return err
7882
}
@@ -81,13 +85,15 @@ func generatePackage(ctx context.Context, p parse.Package) error {
8185
return nil
8286
}
8387

84-
func generateFile(ctx context.Context, p parse.Package, file string, interfaces []parse.Interface) error {
88+
// generateFile generates the mocks for the given file.
89+
// If testify is true, it uses github.com/stretchr/testify for assertions.
90+
func generateFile(ctx context.Context, p parse.Package, file string, interfaces []parse.Interface, testify bool) error {
8591
ctx = log.With(ctx, log.KV{K: "file", V: file})
8692
interfaceNames := make([]string, len(interfaces))
8793
for j, i := range interfaces {
8894
interfaceNames[j] = i.Name()
8995
}
90-
log.Print(ctx, log.KV{K: "interface names", V: interfaceNames})
96+
log.Print(ctx, log.KV{K: "interface names", V: interfaceNames}, log.KV{K: "testify", V: testify})
9197

9298
dir, baseFile := filepath.Split(file)
9399
mocksDir := filepath.Join(dir, "mocks")
@@ -115,7 +121,7 @@ func generateFile(ctx context.Context, p parse.Package, file string, interfaces
115121
}
116122
}()
117123

118-
mocks := generate.NewMocks("mock", p, interfaces, Version)
124+
mocks := generate.NewMocks("mock", p, interfaces, Version, testify)
119125
if err := mocks.Render(f); err != nil {
120126
log.Error(ctx, err)
121127
return err

mock/cmd/cmg/pkg/generate/_tests/testify/mocks/testify.go

Lines changed: 52 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package testify
2+
3+
type (
4+
Testify interface {
5+
Simple(a, b int) bool
6+
}
7+
)

0 commit comments

Comments
 (0)