From 38329888a3d17da96dd32de22b3384d17946ea0e Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Fri, 13 Feb 2026 13:37:16 +0800 Subject: [PATCH] optimize: use context.WithCancelCause and context.AfterFunc to simplify streaming lifecycle control --- go.mod | 2 +- go.sum | 1 + .../grpc => internal/stream}/context.go | 37 +++-------- internal/stream/context_test.go | 56 +++++++++++++++++ pkg/remote/trans/nphttp2/grpc/context_test.go | 43 ------------- pkg/remote/trans/nphttp2/grpc/http2_client.go | 63 ++++++------------- pkg/remote/trans/nphttp2/grpc/http2_server.go | 17 ++++- pkg/remote/trans/nphttp2/grpc/trace_test.go | 3 + pkg/remote/trans/nphttp2/grpc/transport.go | 20 +++--- .../trans/nphttp2/grpc/transport_test.go | 11 +--- pkg/remote/trans/ttstream/client_handler.go | 6 ++ .../trans/ttstream/client_stream_cleanup.go | 33 ---------- pkg/remote/trans/ttstream/context.go | 54 ---------------- pkg/remote/trans/ttstream/stream_client.go | 9 +++ .../trans/ttstream/stream_client_test.go | 9 +-- pkg/remote/trans/ttstream/stream_server.go | 2 +- .../trans/ttstream/stream_server_test.go | 5 +- pkg/remote/trans/ttstream/transport_client.go | 16 ----- pkg/remote/trans/ttstream/transport_server.go | 7 ++- pkg/remote/trans/ttstream/transport_test.go | 11 ++-- 20 files changed, 146 insertions(+), 259 deletions(-) rename {pkg/remote/trans/nphttp2/grpc => internal/stream}/context.go (50%) create mode 100644 internal/stream/context_test.go delete mode 100644 pkg/remote/trans/nphttp2/grpc/context_test.go delete mode 100644 pkg/remote/trans/ttstream/client_stream_cleanup.go delete mode 100644 pkg/remote/trans/ttstream/context.go diff --git a/go.mod b/go.mod index 1a0a009183..b71e706363 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/cloudwego/kitex -go 1.20 +go 1.21 require ( github.com/bytedance/gopkg v0.1.3 diff --git a/go.sum b/go.sum index 471163aaea..06b2eada07 100644 --- a/go.sum +++ b/go.sum @@ -89,6 +89,7 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= diff --git a/pkg/remote/trans/nphttp2/grpc/context.go b/internal/stream/context.go similarity index 50% rename from pkg/remote/trans/nphttp2/grpc/context.go rename to internal/stream/context.go index e35cdb653a..c9d2070dda 100644 --- a/pkg/remote/trans/nphttp2/grpc/context.go +++ b/internal/stream/context.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,41 +14,22 @@ * limitations under the License. */ -package grpc +package stream -import ( - "context" - "sync/atomic" -) +import "context" // contextWithCancelReason implements context.Context -// with a cancel func for passing cancel reason -// NOTE: use context.WithCancelCause when go1.20? +// with Err() retrieving cause err with context.Cause automatically. +// Whether using gRPC or ttstream, the ctx.Err() returns protocol-specific errors rather than context.Canceled or context.DeadlineExceeded. +// When using context.WithCancelCause, an additional layer of encapsulation is still required to avoid breaking changes. type contextWithCancelReason struct { context.Context - - cancel context.CancelFunc - reason atomic.Value } func (c *contextWithCancelReason) Err() error { - err := c.reason.Load() - if err != nil { - return err.(error) - } - return c.Context.Err() -} - -func (c *contextWithCancelReason) CancelWithReason(reason error) { - if reason != nil { - c.reason.CompareAndSwap(nil, reason) - } - c.cancel() + return context.Cause(c.Context) } -type cancelWithReason func(reason error) - -func newContextWithCancelReason(ctx context.Context, cancel context.CancelFunc) (context.Context, cancelWithReason) { - ret := &contextWithCancelReason{Context: ctx, cancel: cancel} - return ret, ret.CancelWithReason +func NewContextWithCancelReason(ctx context.Context) context.Context { + return &contextWithCancelReason{Context: ctx} } diff --git a/internal/stream/context_test.go b/internal/stream/context_test.go new file mode 100644 index 0000000000..c4d04c6128 --- /dev/null +++ b/internal/stream/context_test.go @@ -0,0 +1,56 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stream + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +func Test_contextWithCancelReason(t *testing.T) { + ctx := context.Background() + newCtx := NewContextWithCancelReason(ctx) + test.Assert(t, newCtx.Err() == nil, newCtx.Err()) + + ctx, cancel := context.WithCancel(context.Background()) + newCtx = NewContextWithCancelReason(ctx) + cancel() + test.Assert(t, newCtx.Err() == context.Canceled, newCtx.Err()) + + ctx, cancel = context.WithTimeout(context.Background(), 20*time.Millisecond) + newCtx = NewContextWithCancelReason(ctx) + time.Sleep(50 * time.Millisecond) + cancel() + test.Assert(t, newCtx.Err() == context.DeadlineExceeded, newCtx.Err()) + + ctx, cancelCause := context.WithCancelCause(context.Background()) + newCtx = NewContextWithCancelReason(ctx) + err := errors.New("test") + cancelCause(err) + test.Assert(t, newCtx.Err() == err, newCtx.Err()) + + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancelCause = context.WithCancelCause(ctx) + newCtx = NewContextWithCancelReason(ctx) + cancelCause(err) + cancel() + test.Assert(t, newCtx.Err() == err, newCtx.Err()) +} diff --git a/pkg/remote/trans/nphttp2/grpc/context_test.go b/pkg/remote/trans/nphttp2/grpc/context_test.go deleted file mode 100644 index d731c781fc..0000000000 --- a/pkg/remote/trans/nphttp2/grpc/context_test.go +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package grpc - -import ( - "context" - "errors" - "testing" - - "github.com/cloudwego/kitex/internal/test" -) - -func TestContextWithCancelReason(t *testing.T) { - ctx0, cancel0 := context.WithCancel(context.Background()) - ctx, cancel := newContextWithCancelReason(ctx0, cancel0) - - // cancel contextWithCancelReason - expectErr := errors.New("testing") - cancel(expectErr) - test.Assert(t, ctx0.Err() == context.Canceled) - test.Assert(t, ctx.Err() == expectErr) - - // cancel underlying context - ctx0, cancel0 = context.WithCancel(context.Background()) - ctx, _ = newContextWithCancelReason(ctx0, cancel0) - cancel0() - test.Assert(t, ctx0.Err() == context.Canceled) - test.Assert(t, ctx.Err() == context.Canceled) -} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 435f122a5b..fa64fffdac 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -48,15 +48,8 @@ import ( "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/peer" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/utils" ) -// ticker is used to manage closeStreamTask. -// it triggers and cleans up actively cancelled streams every 5s. -// Streaming QPS is generally not too high, if there is a requirement for timeliness, then consider making it configurable. -// To reduce the overhead of goroutines in a multi-connection scenario, use the Sync SharedTicker -var ticker = utils.NewSyncSharedTicker(5 * time.Second) - // http2Client implements the ClientTransport interface with HTTP2. type http2Client struct { lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. @@ -213,13 +206,6 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, } } t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst) - task := &closeStreamTask{t: t} - t.onClose = func() { - onClose() - // when http2Client has been closed, remove this task - ticker.Delete(task) - } - ticker.Add(task) // Start the reader goroutine for incoming message. Each transport has // a dedicated goroutine which reads HTTP2 frame from network. Then it @@ -292,34 +278,6 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, return t, nil } -// closeStreamTask is used to clean up streams that have been actively cancelled by users -type closeStreamTask struct { - t *http2Client - toCloseStreams []*Stream -} - -func (task *closeStreamTask) Tick() { - trans := task.t - trans.mu.Lock() - for _, stream := range trans.activeStreams { - select { - // judge whether stream has been canceled - case <-stream.Context().Done(): - task.toCloseStreams = append(task.toCloseStreams, stream) - default: - } - } - trans.mu.Unlock() - - for i, stream := range task.toCloseStreams { - // uniformly converted to status error - sErr := ContextErr(stream.Context().Err()) - trans.closeStream(stream, sErr, true, http2.ErrCodeCancel, status.Convert(sErr), nil, false) - task.toCloseStreams[i] = nil - } - task.toCloseStreams = task.toCloseStreams[:0] -} - type clientTransportDump struct { LocalAddress string `json:"local_address"` State transportState `json:"transport_state"` @@ -567,12 +525,23 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } return true } + stop := context.AfterFunc(s.ctx, func() { + sErr := ContextErr(s.ctx.Err()) + t.closeStream(s, sErr, true, http2.ErrCodeCancel, status.Convert(sErr), nil, false) + }) + s.ctxCleanUp = stop + defer func() { + // If exiting abnormally, execute stop to prevent leak + if err != nil { + stop() + } + }() for { - success, err := t.controlBuf.executeAndPut(func(it interface{}) bool { + success, eErr := t.controlBuf.executeAndPut(func(it interface{}) bool { return checkForHeaderListSize(it) && checkForStreamQuota(it) }, hdr) - if err != nil { - return nil, err + if eErr != nil { + return nil, eErr } if success { break @@ -677,6 +646,10 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. t.controlBuf.executeAndPut(addBackStreamQuota, cleanup) // This will unblock write. close(s.done) + // invoke stop func of ctx.AfterFunc to avoid leak + if s.ctxCleanUp != nil { + s.ctxCleanUp() + } } // Close kicks off the shutdown process of the transport. This should be called diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index 7e457682cb..3e06bfe37f 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -38,6 +38,7 @@ import ( "golang.org/x/net/http2/hpack" "google.golang.org/protobuf/proto" + internal_stream "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" "github.com/cloudwego/kitex/pkg/remote/transmeta" @@ -343,13 +344,23 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f // s is just created by the caller. No lock needed. s.state = streamReadDone } - var cancel context.CancelFunc + if state.data.timeoutSet { + var cancel context.CancelFunc + var cancelCause context.CancelCauseFunc s.ctx, cancel = context.WithTimeout(t.ctx, state.data.timeout) + s.ctx, cancelCause = context.WithCancelCause(s.ctx) + s.cancel = func(cause error) { + cancelCause(cause) + cancel() + } } else { - s.ctx, cancel = context.WithCancel(t.ctx) + s.ctx, s.cancel = context.WithCancelCause(t.ctx) } - s.ctx, s.cancel = newContextWithCancelReason(s.ctx, cancel) + // The semantics of `ctx.Err()` have changed. It now returns gRPC internal status error. + // If we just use context.WithCancelCause, users must use context.Cause to retrieve the previous error. This results a breaking change. + // Therefore, we need to encapsulate a Context that automatically executes `context.Cause` when `ctx.Err` is called. + s.ctx = internal_stream.NewContextWithCancelReason(s.ctx) // Attach the received metadata to the context. if len(state.data.mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) diff --git a/pkg/remote/trans/nphttp2/grpc/trace_test.go b/pkg/remote/trans/nphttp2/grpc/trace_test.go index ad78384752..fafb01fd2b 100644 --- a/pkg/remote/trans/nphttp2/grpc/trace_test.go +++ b/pkg/remote/trans/nphttp2/grpc/trace_test.go @@ -23,6 +23,7 @@ import ( "math" "sync" "testing" + "time" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -622,6 +623,7 @@ func Test_trace(t *testing.T) { req := []byte("hello") err = cli.Write(s, nil, req, &Options{}) test.Assert(t, err == nil, err) + time.Sleep(50 * time.Millisecond) cancelFunc() <-srv.srvReady @@ -643,6 +645,7 @@ func Test_trace(t *testing.T) { test.Assert(t, err == nil, err) err = cli.Write(s, nil, nil, &Options{Last: true}) test.Assert(t, err == nil, err) + time.Sleep(50 * time.Millisecond) cancelFunc() <-srv.srvReady diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index c5d37bd821..0ad271fbbf 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -34,6 +34,7 @@ import ( "sync" "sync/atomic" + internal_stream "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" @@ -234,13 +235,14 @@ const ( // Stream represents an RPC in the transport layer. type Stream struct { id uint32 - st ServerTransport // nil for client side Stream - ct *http2Client // nil for server side Stream - ctx context.Context // the associated context of the stream - cancel cancelWithReason // always nil for client side Stream - done chan struct{} // closed at the end of stream to unblock writers. On the client side. - ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) - method string // the associated RPC method of the stream + st ServerTransport // nil for client side Stream + ct *http2Client // nil for server side Stream + ctx context.Context // the associated context of the stream + cancel context.CancelCauseFunc + done chan struct{} // closed at the end of stream to unblock writers. On the client side. + ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) + ctxCleanUp func() bool // the stop func of context.AfterFunc, nil for server side Stream + method string // the associated RPC method of the stream recvCompress string sendCompress string buf *recvBuffer @@ -555,8 +557,8 @@ func CreateStream(ctx context.Context, id uint32, requestRead func(i int), metho hdrMu: sync.Mutex{}, } - ctx, cancel := context.WithCancel(ctx) - stream.ctx, stream.cancel = newContextWithCancelReason(ctx, cancel) + stream.ctx, stream.cancel = context.WithCancelCause(ctx) + stream.ctx = internal_stream.NewContextWithCancelReason(stream.ctx) return stream } diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 57e710e079..d886c87978 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -48,7 +48,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/testutils" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" - "github.com/cloudwego/kitex/pkg/utils" ) type server struct { @@ -371,7 +370,7 @@ func (h *testStreamHandler) handleStreamCancel(t *testing.T, s *Stream) { func verifyCancelError(t *testing.T, err error) { test.Assert(t, err != nil, err) st, ok := status.FromError(err) - test.Assert(t, ok) + test.Assert(t, ok, err) test.Assert(t, st.Code() == codes.Canceled, st.Code()) test.Assert(t, strings.Contains(st.Message(), "transport: RSTStream Frame received with error code"), st.Message()) } @@ -855,14 +854,6 @@ func setUpWithOnGoAway(t *testing.T, port int, serverConfig *ServerConfig, ht hT return server, ct.(*http2Client) } -func TestMain(m *testing.M) { - // set the ticker to make tests running fast - oldTicker := ticker - ticker = utils.NewSyncSharedTicker(10 * time.Millisecond) - m.Run() - ticker = oldTicker -} - // TestInflightStreamClosing ensures that closing in-flight stream // sends status error to concurrent stream reader. func TestInflightStreamClosing(t *testing.T) { diff --git a/pkg/remote/trans/ttstream/client_handler.go b/pkg/remote/trans/ttstream/client_handler.go index 2b63c55bb7..96bc98aaa0 100644 --- a/pkg/remote/trans/ttstream/client_handler.go +++ b/pkg/remote/trans/ttstream/client_handler.go @@ -87,8 +87,14 @@ func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) ( cs.setRecvTimeoutConfig(rconfig) cs.setMetaFrameHandler(c.metaHandler) cs.setTraceController(c.traceCtl) + stop := context.AfterFunc(cs.ctx, func() { + cs.ctxDoneCallback(cs.ctx) + }) + cs.setCtxCleanup(stop) if err = trans.WriteStream(ctx, cs, intHeader, strHeader); err != nil { + // execute stop to prevent leak + stop() return nil, err } diff --git a/pkg/remote/trans/ttstream/client_stream_cleanup.go b/pkg/remote/trans/ttstream/client_stream_cleanup.go deleted file mode 100644 index 1530e14bd5..0000000000 --- a/pkg/remote/trans/ttstream/client_stream_cleanup.go +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream - -func (t *clientTransport) Tick() { - var toCloseStreams []*clientStream - t.streams.Range(func(key, value any) bool { - s := value.(*clientStream) - select { - case <-s.ctx.Done(): - toCloseStreams = append(toCloseStreams, s) - default: - } - return true - }) - for _, s := range toCloseStreams { - s.ctxDoneCallback(s.ctx) - } -} diff --git a/pkg/remote/trans/ttstream/context.go b/pkg/remote/trans/ttstream/context.go deleted file mode 100644 index acad65634f..0000000000 --- a/pkg/remote/trans/ttstream/context.go +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream - -import ( - "context" - "sync/atomic" -) - -// contextWithCancelReason implements context.Context -// with a cancel func for passing cancel reason -// NOTE: use context.WithCancelCause when go1.20? -type contextWithCancelReason struct { - context.Context - - cancel context.CancelFunc - reason atomic.Value -} - -func (c *contextWithCancelReason) Err() error { - err := c.reason.Load() - if err != nil { - return err.(error) - } - return c.Context.Err() -} - -func (c *contextWithCancelReason) CancelWithReason(reason error) { - if reason != nil { - c.reason.CompareAndSwap(nil, reason) - } - c.cancel() -} - -type cancelWithReason func(reason error) - -func newContextWithCancelReason(ctx context.Context, cancel context.CancelFunc) (context.Context, cancelWithReason) { - ret := &contextWithCancelReason{Context: ctx, cancel: cancel} - return ret, ret.CancelWithReason -} diff --git a/pkg/remote/trans/ttstream/stream_client.go b/pkg/remote/trans/ttstream/stream_client.go index 370788950a..9102f9bf14 100644 --- a/pkg/remote/trans/ttstream/stream_client.go +++ b/pkg/remote/trans/ttstream/stream_client.go @@ -63,6 +63,7 @@ type clientStream struct { // exception as Recv closeStreamException atomic.Value // type must be of *Exception storeExceptionOnce sync.Once + ctxCleanup func() bool // for Header()/Trailer() headerSig chan int32 @@ -206,6 +207,10 @@ func (s *clientStream) close(exception error, sendRst bool, cancelPath string, t s.sendTrailer(nil) } s.runCloseCallback(exception) + // invoke stop func of ctx.AfterFunc to avoid leak + if s.ctxCleanup != nil { + s.ctxCleanup() + } } func (s *clientStream) closeSignalMeta(trailer streaming.Trailer) { @@ -239,6 +244,10 @@ func (s *clientStream) setTraceController(traceCtl *rpcinfo.TraceController) { s.traceCtl = traceCtl } +func (s *clientStream) setCtxCleanup(clean func() bool) { + s.ctxCleanup = clean +} + func (s *clientStream) handleStreamStartEvent(event rpcinfo.StreamStartEvent) { if s.traceCtl == nil { return diff --git a/pkg/remote/trans/ttstream/stream_client_test.go b/pkg/remote/trans/ttstream/stream_client_test.go index 7ab350668c..ea781b2a2e 100644 --- a/pkg/remote/trans/ttstream/stream_client_test.go +++ b/pkg/remote/trans/ttstream/stream_client_test.go @@ -26,6 +26,7 @@ import ( "testing" "time" + internal_stream "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streaming" @@ -109,8 +110,8 @@ func Test_clientStream_parseCtxErr(t *testing.T) { ctxFunc: func() (context.Context, context.CancelFunc) { ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), rpcinfo.NewRPCInfo( rpcinfo.NewEndpointInfo("serviceB", "testMethod", nil, nil), nil, nil, rpcinfo.NewRPCConfig(), nil)) - ctx, cancel := context.WithCancel(ctx) - ctx, cancelFunc := newContextWithCancelReason(ctx, cancel) + ctx, cancelFunc := context.WithCancelCause(ctx) + ctx = internal_stream.NewContextWithCancelReason(ctx) cause := defaultRstException return ctx, func() { cancelFunc(errUpstreamCancel.newBuilder().withSide(serverSide).setOrAppendCancelPath("serviceA").withCauseAndTypeId(cause, cause.TypeId())) @@ -156,8 +157,8 @@ func Test_clientStream_parseCtxErr(t *testing.T) { ctxFunc: func() (context.Context, context.CancelFunc) { ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), rpcinfo.NewRPCInfo( rpcinfo.NewEndpointInfo("serviceA", "testMethod", nil, nil), nil, nil, rpcinfo.NewRPCConfig(), nil)) - ctx, cancel := context.WithCancel(ctx) - ctx, cancelFunc := newContextWithCancelReason(ctx, cancel) + ctx, cancelFunc := context.WithCancelCause(ctx) + ctx = internal_stream.NewContextWithCancelReason(ctx) return ctx, func() { cancelFunc(errors.New("test")) } diff --git a/pkg/remote/trans/ttstream/stream_server.go b/pkg/remote/trans/ttstream/stream_server.go index 14d5e3b7bd..b3d73f675f 100644 --- a/pkg/remote/trans/ttstream/stream_server.go +++ b/pkg/remote/trans/ttstream/stream_server.go @@ -50,7 +50,7 @@ func newServerStream(ctx context.Context, writer streamWriter, smeta streamFrame type serverStream struct { *stream state int32 - cancelFunc cancelWithReason + cancelFunc context.CancelCauseFunc } func (s *serverStream) SetHeader(hd streaming.Header) error { diff --git a/pkg/remote/trans/ttstream/stream_server_test.go b/pkg/remote/trans/ttstream/stream_server_test.go index fcab6b5431..1bdbf5f1b9 100644 --- a/pkg/remote/trans/ttstream/stream_server_test.go +++ b/pkg/remote/trans/ttstream/stream_server_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + internal_stream "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/streaming" ) @@ -34,8 +35,8 @@ func newTestServerStream() *serverStream { } func newTestServerStreamWithStreamWriter(w streamWriter) *serverStream { - ctx, cancel := context.WithCancel(context.Background()) - ctx, cancelFunc := newContextWithCancelReason(ctx, cancel) + ctx, cancelFunc := context.WithCancelCause(context.Background()) + ctx = internal_stream.NewContextWithCancelReason(ctx) srvSt := newServerStream(ctx, w, streamFrame{}) srvSt.cancelFunc = cancelFunc return srvSt diff --git a/pkg/remote/trans/ttstream/transport_client.go b/pkg/remote/trans/ttstream/transport_client.go index ede3c06382..511e2bf04e 100644 --- a/pkg/remote/trans/ttstream/transport_client.go +++ b/pkg/remote/trans/ttstream/transport_client.go @@ -21,7 +21,6 @@ import ( "net" "sync" "sync/atomic" - "time" "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/netpoll" @@ -29,17 +28,8 @@ import ( "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streaming" - "github.com/cloudwego/kitex/pkg/utils" ) -// ticker is used to manage cleaning canceled stream task. -// it triggers and cleans up actively cancelled streams every 5s. -// Streaming QPS is generally not too high so that using the Sync SharedTicker to reduce -// the overhead of goroutines in a multi-connection scenario. -// -// This is a workaround: when the minimum Go version supports 1.21, use `context.AfterFunc` instead. -var ticker = utils.NewSyncSharedTicker(5 * time.Second) - type clientTransport struct { conn netpoll.Connection pool transPool @@ -83,9 +73,6 @@ func newClientTransport(conn netpoll.Connection, pool transPool) *clientTranspor err = t.loopRead() }, gofunc.NewBasicInfo("", addr)) - // add to stream cleanup ticker - ticker.Add(t) - return t } @@ -130,9 +117,6 @@ func (t *clientTransport) releaseResources(err error) { if cErr := t.conn.Close(); cErr != nil { klog.Infof("KITEX: ttstream clientTransport Close Connection failed, err: %v", cErr) } - - // remove cleanup stream task from ticker to avoid goroutine leak - ticker.Delete(t) } // WaitClosed waits for send loop and recv loop closed diff --git a/pkg/remote/trans/ttstream/transport_server.go b/pkg/remote/trans/ttstream/transport_server.go index 7c98176a10..26e45902c3 100644 --- a/pkg/remote/trans/ttstream/transport_server.go +++ b/pkg/remote/trans/ttstream/transport_server.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/netpoll" + internal_stream "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/trans/ttstream/container" @@ -167,10 +168,10 @@ func (t *serverTransport) readFrame(reader bufiox.Reader) error { var s *serverStream if fr.typ == headerFrameType { // server recv a header frame, we should create a new stream - ctx, cancel := context.WithCancel(context.Background()) - ctx, cFunc := newContextWithCancelReason(ctx, cancel) + ctx, cancelFunc := context.WithCancelCause(context.Background()) + ctx = internal_stream.NewContextWithCancelReason(ctx) s = newServerStream(ctx, t, fr.streamFrame) - s.cancelFunc = cFunc + s.cancelFunc = cancelFunc t.storeStream(s) err = t.spipe.Write(context.Background(), s) } else { diff --git a/pkg/remote/trans/ttstream/transport_test.go b/pkg/remote/trans/ttstream/transport_test.go index 73f8c68ac8..231f38a4bd 100644 --- a/pkg/remote/trans/ttstream/transport_test.go +++ b/pkg/remote/trans/ttstream/transport_test.go @@ -40,7 +40,6 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streaming" - "github.com/cloudwego/kitex/pkg/utils" ) var testServiceInfo = &serviceinfo.ServiceInfo{ @@ -58,12 +57,6 @@ var testServiceInfo = &serviceinfo.ServiceInfo{ }, } -func TestMain(m *testing.M) { - // reduce cleanup interval to speed up unit tests - ticker = utils.NewSyncSharedTicker(50 * time.Millisecond) - m.Run() -} - func TestTransportBasic(t *testing.T) { cfd, sfd := netpoll.GetSysFdPairs() cconn, err := netpoll.NewFDConnection(cfd) @@ -574,6 +567,10 @@ func initTestStreams(t *testing.T, cCtx context.Context, method, cliNodeName, sr cs := newClientStream(cCtx, ctrans, streamFrame{sid: genStreamID(), method: method}) cs.rpcInfo = rpcinfo.NewRPCInfo( rpcinfo.NewEndpointInfo(cliNodeName, method, nil, nil), nil, nil, nil, nil) + stop := context.AfterFunc(cs.ctx, func() { + cs.ctxDoneCallback(cs.ctx) + }) + cs.ctxCleanup = stop err = ctrans.WriteStream(cCtx, cs, intHeader, strHeader) test.Assert(t, err == nil, err) strans := newServerTransport(sconn)