Skip to content

Commit d179fba

Browse files
authored
Fix Trace Stream interceptors when RecvWithContext returns an error (#557)
* Fix Trace Stream interceptors when `RecvWithContext` returns an error The implementation of the Trace Stream interceptors' `RecvWithContext` handling did not take into account an error return where the message (result or payload) would be `nil`. This change changes the way `RecvWithContext` calls are intercepted so that if the `next` returns an error, the interceptor will populate the receive context using the passed in context and not even attempt to gather any OTel trace metadata from a message (which usually would not be available in the error case, but may be possible if another Goa interceptor were in the chain). I don't believe it is worthwhile to support a case where an error is returned by `next` along with a non-`nil` message, but if we did, we would need to resort to using reflection due to the current state of Go generics and `nil` checking: ```go func traceStreamRecv[Message TraceStreamStreamingRecvMessage]( ctx context.Context, info goa.InterceptorInfo, next goa.Endpoint, streamingMessage func(any) Message, ) (any, error) { msg, err := next(ctx, info.RawPayload()) propagator := otel.GetTextMapPropagator() sm := streamingMessage(msg) rc, ok := ctx.Value(traceStreamRecvContextKey).(*traceStreamRecvContext) if !ok { panic(fmt.Errorf("clue interceptors trace stream receive method called without prior setup (service: %v, method: %v)", info.Service(), info.Method())) } if reflect.ValueOf(sm).Elem().IsZero() { rc.ctx = ctx } else { rc.ctx = propagator.Extract(ctx, propagation.MapCarrier(sm.TraceMetadata())) } return msg, err } ``` * Return `nil` instead of `err` since it is unnecessary
1 parent 06e0199 commit d179fba

File tree

3 files changed

+131
-15
lines changed

3 files changed

+131
-15
lines changed

interceptors/trace_stream.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,18 @@ func traceStreamRecv[Message TraceStreamStreamingRecvMessage](
9898
streamingMessage func(any) Message,
9999
) (any, error) {
100100
msg, err := next(ctx, info.RawPayload())
101-
propagator := otel.GetTextMapPropagator()
102-
sm := streamingMessage(msg)
103101
rc, ok := ctx.Value(traceStreamRecvContextKey).(*traceStreamRecvContext)
104102
if !ok {
105103
panic(fmt.Errorf("clue interceptors trace stream receive method called without prior setup (service: %v, method: %v)", info.Service(), info.Method()))
106104
}
105+
if err != nil {
106+
rc.ctx = ctx
107+
return nil, err
108+
}
109+
propagator := otel.GetTextMapPropagator()
110+
sm := streamingMessage(msg)
107111
rc.ctx = propagator.Extract(ctx, propagation.MapCarrier(sm.TraceMetadata()))
108-
return msg, err
112+
return msg, nil
109113
}
110114

111115
// traceStreamWrapRecvAndReturnContext is a helper function for wrapped trace stream receive methods

interceptors/trace_stream_client_test.go

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ func TestTraceBidirectionalStreamClientInterceptor(t *testing.T) {
8282
info.addRawPayload(func() any {
8383
return nil
8484
})
85-
info.addClientStreamingResult(func(res any) *mockTraceStreamingRecvMessage {
86-
return newMockTraceStreamingRecvMessage(assert)
87-
})
8885
info.addService(func() string {
8986
return "TestService"
9087
})
@@ -157,6 +154,37 @@ func TestTraceBidirectionalStreamClientInterceptor(t *testing.T) {
157154
assert.False(payload.hasMore(), "missing expected payload calls")
158155
})
159156

157+
t.Run("receive with error", func(t *testing.T) {
158+
var (
159+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
160+
info = newMockTraceStreamInfo(assert.New(t))
161+
interceptor = &TraceBidirectionalStreamClientInterceptor[*mockTraceStreamInfo, *mockTraceStreamingSendMessage, *mockTraceStreamingRecvMessage]{}
162+
nextCalled = false
163+
next = func(ctx context.Context, _ any) (any, error) {
164+
nextCalled = true
165+
return nil, assert.AnError
166+
}
167+
)
168+
info.addCallType(func() goa.InterceptorCallType {
169+
return goa.InterceptorStreamingRecv
170+
})
171+
info.addRawPayload(func() any {
172+
return nil
173+
})
174+
175+
ctx = SetupTraceStreamRecvContext(ctx)
176+
res, err := interceptor.TraceBidirectionalStream(ctx, info, next)
177+
assert.ErrorIs(t, err, assert.AnError)
178+
assert.Nil(t, res)
179+
180+
assert.NotPanics(t, func() {
181+
ctx = GetTraceStreamRecvContext(ctx)
182+
})
183+
184+
assert.True(t, nextCalled, "missing expected next call")
185+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
186+
})
187+
160188
t.Run("unary", func(t *testing.T) {
161189
var (
162190
assert = assert.New(t)
@@ -205,9 +233,6 @@ func TestTraceServerToClientStreamClientInterceptor(t *testing.T) {
205233
info.addRawPayload(func() any {
206234
return nil
207235
})
208-
info.addClientStreamingResult(func(res any) *mockTraceStreamingRecvMessage {
209-
return newMockTraceStreamingRecvMessage(assert)
210-
})
211236
info.addService(func() string {
212237
return "TestService"
213238
})
@@ -280,6 +305,37 @@ func TestTraceServerToClientStreamClientInterceptor(t *testing.T) {
280305
assert.False(payload.hasMore(), "missing expected payload calls")
281306
})
282307

308+
t.Run("receive with error", func(t *testing.T) {
309+
var (
310+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
311+
info = newMockTraceStreamInfo(assert.New(t))
312+
interceptor = &TraceServerToClientStreamClientInterceptor[*mockTraceStreamInfo, *mockTraceStreamingRecvMessage]{}
313+
nextCalled = false
314+
next = func(ctx context.Context, _ any) (any, error) {
315+
nextCalled = true
316+
return nil, assert.AnError
317+
}
318+
)
319+
info.addCallType(func() goa.InterceptorCallType {
320+
return goa.InterceptorStreamingRecv
321+
})
322+
info.addRawPayload(func() any {
323+
return nil
324+
})
325+
326+
ctx = SetupTraceStreamRecvContext(ctx)
327+
res, err := interceptor.TraceServerToClientStream(ctx, info, next)
328+
assert.ErrorIs(t, err, assert.AnError)
329+
assert.Nil(t, res)
330+
331+
assert.NotPanics(t, func() {
332+
ctx = GetTraceStreamRecvContext(ctx)
333+
})
334+
335+
assert.True(t, nextCalled, "missing expected next call")
336+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
337+
})
338+
283339
t.Run("unary", func(t *testing.T) {
284340
var (
285341
assert = assert.New(t)

interceptors/trace_stream_server_test.go

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ func TestTraceBidirectionalStreamServerInterceptor(t *testing.T) {
8282
info.addRawPayload(func() any {
8383
return nil
8484
})
85-
info.addServerStreamingPayload(func(pay any) *mockTraceStreamingRecvMessage {
86-
return newMockTraceStreamingRecvMessage(assert)
87-
})
8885
info.addService(func() string {
8986
return "TestService"
9087
})
@@ -157,6 +154,37 @@ func TestTraceBidirectionalStreamServerInterceptor(t *testing.T) {
157154
assert.False(payload.hasMore(), "missing expected payload calls")
158155
})
159156

157+
t.Run("receive with error", func(t *testing.T) {
158+
var (
159+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
160+
info = newMockTraceStreamInfo(assert.New(t))
161+
interceptor = &TraceBidirectionalStreamServerInterceptor[*mockTraceStreamInfo, *mockTraceStreamingRecvMessage, *mockTraceStreamingSendMessage]{}
162+
nextCalled = false
163+
next = func(ctx context.Context, payload any) (any, error) {
164+
nextCalled = true
165+
return nil, assert.AnError
166+
}
167+
)
168+
info.addCallType(func() goa.InterceptorCallType {
169+
return goa.InterceptorStreamingRecv
170+
})
171+
info.addRawPayload(func() any {
172+
return nil
173+
})
174+
175+
ctx = SetupTraceStreamRecvContext(ctx)
176+
res, err := interceptor.TraceBidirectionalStream(ctx, info, next)
177+
assert.ErrorIs(t, err, assert.AnError)
178+
assert.Nil(t, res)
179+
180+
assert.NotPanics(t, func() {
181+
ctx = GetTraceStreamRecvContext(ctx)
182+
})
183+
184+
assert.True(t, nextCalled, "missing expected next call")
185+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
186+
})
187+
160188
t.Run("unary", func(t *testing.T) {
161189
var (
162190
assert = assert.New(t)
@@ -285,9 +313,6 @@ func TestTraceClientToServerStreamServerInterceptor(t *testing.T) {
285313
info.addRawPayload(func() any {
286314
return nil
287315
})
288-
info.addServerStreamingPayload(func(pay any) *mockTraceStreamingRecvMessage {
289-
return newMockTraceStreamingRecvMessage(assert)
290-
})
291316
info.addService(func() string {
292317
return "TestService"
293318
})
@@ -360,6 +385,37 @@ func TestTraceClientToServerStreamServerInterceptor(t *testing.T) {
360385
assert.False(payload.hasMore(), "missing expected payload calls")
361386
})
362387

388+
t.Run("receive with error", func(t *testing.T) {
389+
var (
390+
ctx = log.Context(context.Background(), log.WithFormat(log.FormatText))
391+
info = newMockTraceStreamInfo(assert.New(t))
392+
interceptor = &TraceClientToServerStreamServerInterceptor[*mockTraceStreamInfo, *mockTraceStreamingRecvMessage]{}
393+
nextCalled = false
394+
next = func(ctx context.Context, payload any) (any, error) {
395+
nextCalled = true
396+
return nil, assert.AnError
397+
}
398+
)
399+
info.addCallType(func() goa.InterceptorCallType {
400+
return goa.InterceptorStreamingRecv
401+
})
402+
info.addRawPayload(func() any {
403+
return nil
404+
})
405+
406+
ctx = SetupTraceStreamRecvContext(ctx)
407+
res, err := interceptor.TraceClientToServerStream(ctx, info, next)
408+
assert.ErrorIs(t, err, assert.AnError)
409+
assert.Nil(t, res)
410+
411+
assert.NotPanics(t, func() {
412+
ctx = GetTraceStreamRecvContext(ctx)
413+
})
414+
415+
assert.True(t, nextCalled, "missing expected next call")
416+
assert.False(t, info.hasMore(), "missing expected interceptor info calls")
417+
})
418+
363419
t.Run("unary", func(t *testing.T) {
364420
var (
365421
assert = assert.New(t)

0 commit comments

Comments
 (0)