Skip to content

Commit 102e0cd

Browse files
authored
Refactor keepalive logic (ortuman#105)
1 parent 87e8e2b commit 102e0cd

28 files changed

+204
-540
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
55
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
66

7-
## [0.10.0] - 2020-03-18
7+
## [0.10.1] - 2020-03-22
88
### Changed
99
- Set resource limit
1010

c2s/config.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ type TransportConfig struct {
7373
Type transport.Type
7474
BindAddress string
7575
Port int
76-
KeepAlive time.Duration
7776
URLPath string
7877
}
7978

@@ -96,9 +95,6 @@ func (t *TransportConfig) UnmarshalYAML(unmarshal func(interface{}) error) error
9695
case "", "socket":
9796
t.Type = transport.Socket
9897

99-
case "websocket":
100-
t.Type = transport.WebSocket
101-
10298
default:
10399
return fmt.Errorf("c2s.TransportConfig: unrecognized transport type: %s", p.Type)
104100
}
@@ -114,10 +110,6 @@ func (t *TransportConfig) UnmarshalYAML(unmarshal func(interface{}) error) error
114110
if t.Port == 0 {
115111
t.Port = defaultTransportPort
116112
}
117-
t.KeepAlive = time.Duration(p.KeepAlive) * time.Second
118-
if t.KeepAlive == 0 {
119-
t.KeepAlive = defaultTransportKeepAlive
120-
}
121113
return nil
122114
}
123115

@@ -132,6 +124,7 @@ type Config struct {
132124
ID string
133125
ConnectTimeout time.Duration
134126
Timeout time.Duration
127+
KeepAlive time.Duration
135128
MaxStanzaSize int
136129
ResourceConflict ResourceConflictPolicy
137130
Transport TransportConfig
@@ -145,6 +138,7 @@ type configProxy struct {
145138
TLS TLSConfig `yaml:"tls"`
146139
ConnectTimeout int `yaml:"connect_timeout"`
147140
Timeout int `yaml:"timeout"`
141+
KeepAlive int `yaml:"keep_alive"`
148142
MaxStanzaSize int `yaml:"max_stanza_size"`
149143
ResourceConflict string `yaml:"resource_conflict"`
150144
Transport TransportConfig `yaml:"transport"`
@@ -167,6 +161,10 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
167161
if cfg.Timeout == 0 {
168162
cfg.Timeout = defaultTimeout
169163
}
164+
cfg.KeepAlive = time.Duration(p.KeepAlive) * time.Second
165+
if cfg.KeepAlive == 0 {
166+
cfg.KeepAlive = defaultTransportKeepAlive
167+
}
170168
cfg.MaxStanzaSize = p.MaxStanzaSize
171169
if cfg.MaxStanzaSize == 0 {
172170
cfg.MaxStanzaSize = defaultMaxStanzaSize
@@ -200,9 +198,9 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
200198
}
201199

202200
type streamConfig struct {
203-
transport transport.Transport
204201
connectTimeout time.Duration
205202
timeout time.Duration
203+
keepAlive time.Duration
206204
maxStanzaSize int
207205
resourceConflict ResourceConflictPolicy
208206
sasl []string

c2s/config_test.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ package c2s
88
import (
99
"os"
1010
"testing"
11-
"time"
1211

1312
"github.com/ortuman/jackal/transport"
1413
"github.com/ortuman/jackal/transport/compress"
@@ -46,14 +45,6 @@ func TestTransportConfig(t *testing.T) {
4645
require.Equal(t, transport.Socket, s.Type)
4746
require.Equal(t, "0.0.0.0", s.BindAddress)
4847
require.Equal(t, 5222, s.Port)
49-
require.Equal(t, time.Second*time.Duration(120), s.KeepAlive)
50-
51-
err = yaml.Unmarshal([]byte("{type: websocket, url_path: /xmpp/ws}"), &s)
52-
require.Nil(t, err)
53-
54-
require.Equal(t, transport.WebSocket, s.Type)
55-
require.Equal(t, 5222, s.Port)
56-
require.Equal(t, time.Second*time.Duration(120), s.KeepAlive)
5748
}
5849

5950
func TestConfig(t *testing.T) {

c2s/in.go

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ type inStream struct {
4646
mods *module.Modules
4747
comps *component.Components
4848
sess *session.Session
49+
tr transport.Transport
50+
mu sync.RWMutex
4951
id string
5052
connectTm *time.Timer
53+
readTimeoutTm *time.Timer
5154
state uint32
5255
authenticators []auth.Authenticator
5356
activeAuth auth.Authenticator
5457
runQueue *runqueue.RunQueue
55-
mu sync.RWMutex
5658
jid *jid.JID
5759
secured bool
5860
compressed bool
@@ -63,10 +65,11 @@ type inStream struct {
6365
ctxCancelFn context.CancelFunc
6466
}
6567

66-
func newStream(id string, config *streamConfig, mods *module.Modules, comps *component.Components, router router.Router, userRep repository.User, blockListRep repository.BlockList) stream.C2S {
68+
func newStream(id string, config *streamConfig, tr transport.Transport, mods *module.Modules, comps *component.Components, router router.Router, userRep repository.User, blockListRep repository.BlockList) stream.C2S {
6769
ctx, ctxCancelFn := context.WithCancel(context.Background())
6870
s := &inStream{
6971
cfg: config,
72+
tr: tr,
7073
router: router,
7174
userRep: userRep,
7275
blockListRep: blockListRep,
@@ -79,7 +82,7 @@ func newStream(id string, config *streamConfig, mods *module.Modules, comps *com
7982
}
8083

8184
// initialize stream context
82-
secured := !(config.transport.Type() == transport.Socket)
85+
secured := !(tr.Type() == transport.Socket)
8386
s.setSecured(secured)
8487
s.setJID(&jid.JID{})
8588

@@ -179,14 +182,13 @@ func (s *inStream) Disconnect(ctx context.Context, err error) {
179182
waitCh := make(chan struct{})
180183
s.runQueue.Run(func() {
181184
s.disconnect(ctx, err)
182-
s.ctxCancelFn()
183185
close(waitCh)
184186
})
185187
<-waitCh
186188
}
187189

188190
func (s *inStream) initializeAuthenticators() {
189-
tr := s.cfg.transport
191+
tr := s.tr
190192
var authenticators []auth.Authenticator
191193
for _, a := range s.cfg.sasl {
192194
switch a {
@@ -266,7 +268,7 @@ func (s *inStream) handleConnecting(ctx context.Context, elem xmpp.XElement) {
266268
func (s *inStream) unauthenticatedFeatures() []xmpp.XElement {
267269
var features []xmpp.XElement
268270

269-
isSocketTr := s.cfg.transport.Type() == transport.Socket
271+
isSocketTr := s.tr.Type() == transport.Socket
270272

271273
if isSocketTr && !s.IsSecured() {
272274
startTLS := xmpp.NewElementName("starttls")
@@ -302,7 +304,7 @@ func (s *inStream) unauthenticatedFeatures() []xmpp.XElement {
302304
func (s *inStream) authenticatedFeatures() []xmpp.XElement {
303305
var features []xmpp.XElement
304306

305-
isSocketTr := s.cfg.transport.Type() == transport.Socket
307+
isSocketTr := s.tr.Type() == transport.Socket
306308

307309
// attach compression feature
308310
compressionAvailable := isSocketTr && s.cfg.compression.Level != compress.NoCompression
@@ -444,7 +446,7 @@ func (s *inStream) proceedStartTLS(ctx context.Context, elem xmpp.XElement) {
444446
s.setSecured(true)
445447
s.writeElement(ctx, xmpp.NewElementNamespace("proceed", tlsNamespace))
446448

447-
s.cfg.transport.StartTLS(&tls.Config{Certificates: s.router.Hosts().Certificates()}, false)
449+
s.tr.StartTLS(&tls.Config{Certificates: s.router.Hosts().Certificates()}, false)
448450

449451
log.Infof("secured stream... id: %s", s.id)
450452
s.restartSession()
@@ -470,7 +472,7 @@ func (s *inStream) compress(ctx context.Context, elem xmpp.XElement) {
470472
}
471473
s.writeElement(ctx, xmpp.NewElementNamespace("compressed", compressProtocolNamespace))
472474

473-
s.cfg.transport.EnableCompression(s.cfg.compression.Level)
475+
s.tr.EnableCompression(s.cfg.compression.Level)
474476
s.setCompressed(true)
475477

476478
log.Infof("compressed stream... id: %s", s.id)
@@ -698,7 +700,9 @@ sendMessage:
698700

699701
// Runs on it's own goroutine
700702
func (s *inStream) doRead() {
703+
s.scheduleReadTimeout()
701704
elem, sErr := s.sess.Receive()
705+
s.cancelReadTimeout()
702706

703707
ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout)
704708
if sErr == nil {
@@ -796,12 +800,14 @@ func (s *inStream) disconnectClosingSession(ctx context.Context, closeSession, u
796800
if unbind {
797801
s.router.Unbind(ctx, s.JID())
798802
}
803+
s.ctxCancelFn()
804+
799805
// notify disconnection
800806
if s.cfg.onDisconnect != nil {
801807
s.cfg.onDisconnect(s)
802808
}
803809
s.setState(disconnected)
804-
_ = s.cfg.transport.Close()
810+
_ = s.tr.Close()
805811

806812
s.runQueue.Stop(nil) // stop processing messages
807813
}
@@ -831,9 +837,8 @@ func (s *inStream) isBlockedJID(ctx context.Context, j *jid.JID) bool {
831837
func (s *inStream) restartSession() {
832838
s.sess = session.New(s.id, &session.Config{
833839
JID: s.JID(),
834-
Transport: s.cfg.transport,
835840
MaxStanzaSize: s.cfg.maxStanzaSize,
836-
}, s.router.Hosts())
841+
}, s.tr, s.router.Hosts())
837842
s.setState(connecting)
838843
}
839844

@@ -885,6 +890,25 @@ func (s *inStream) setSessionStarted(sessStarted bool) {
885890
s.sessStarted = sessStarted
886891
}
887892

893+
func (s *inStream) scheduleReadTimeout() {
894+
s.mu.Lock()
895+
s.readTimeoutTm = time.AfterFunc(s.cfg.keepAlive, s.readTimeout)
896+
s.mu.Unlock()
897+
}
898+
899+
func (s *inStream) cancelReadTimeout() {
900+
s.mu.Lock()
901+
s.readTimeoutTm.Stop()
902+
s.mu.Unlock()
903+
}
904+
905+
func (s *inStream) readTimeout() {
906+
s.runQueue.Run(func() {
907+
ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout)
908+
s.disconnect(ctx, streamerror.ErrConnectionTimeout)
909+
})
910+
}
911+
888912
func (s *inStream) setState(state uint32) {
889913
atomic.StoreUint32(&s.state, state)
890914
}

c2s/in_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,11 @@ func tUtilStreamStartSession(conn *fakeSocketConn, t *testing.T) {
423423

424424
func tUtilStreamInit(r router.Router, userRep repository.User, blockListRep repository.BlockList) (*inStream, *fakeSocketConn) {
425425
conn := newFakeSocketConn()
426-
tr := transport.NewSocketTransport(conn, 4096)
426+
tr := transport.NewSocketTransport(conn)
427427
stm := newStream(
428428
"abcd1234",
429-
tUtilInStreamDefaultConfig(tr),
429+
tUtilInStreamDefaultConfig(),
430+
tr,
430431
tUtilInitModules(r),
431432
&component.Components{},
432433
r,
@@ -435,10 +436,10 @@ func tUtilStreamInit(r router.Router, userRep repository.User, blockListRep repo
435436
return stm.(*inStream), conn
436437
}
437438

438-
func tUtilInStreamDefaultConfig(tr transport.Transport) *streamConfig {
439+
func tUtilInStreamDefaultConfig() *streamConfig {
439440
return &streamConfig{
440441
connectTimeout: time.Second,
441-
transport: tr,
442+
keepAlive: time.Second,
442443
maxStanzaSize: 8192,
443444
resourceConflict: Reject,
444445
compression: CompressConfig{Level: compress.DefaultCompression},

c2s/server.go

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ package c2s
77

88
import (
99
"context"
10-
"crypto/tls"
1110
"fmt"
1211
"net"
1312
"net/http"
1413
"strconv"
1514
"sync"
1615
"sync/atomic"
16+
"time"
1717

1818
"github.com/gorilla/websocket"
1919
"github.com/ortuman/jackal/component"
@@ -67,9 +67,6 @@ func (s *server) start() {
6767
switch s.cfg.Transport.Type {
6868
case transport.Socket:
6969
err = s.listenSocketConn(address)
70-
case transport.WebSocket:
71-
err = s.listenWebSocketConn(address)
72-
break
7370
}
7471
if err != nil {
7572
log.Fatalf("%v", err)
@@ -87,40 +84,13 @@ func (s *server) listenSocketConn(address string) error {
8784
for atomic.LoadUint32(&s.listening) == 1 {
8885
conn, err := ln.Accept()
8986
if err == nil {
90-
go s.startStream(transport.NewSocketTransport(conn, s.cfg.Transport.KeepAlive))
87+
go s.startStream(transport.NewSocketTransport(conn), s.cfg.KeepAlive)
9188
continue
9289
}
9390
}
9491
return nil
9592
}
9693

97-
func (s *server) listenWebSocketConn(address string) error {
98-
http.HandleFunc(s.cfg.Transport.URLPath, s.websocketUpgrade)
99-
100-
s.wsSrv = &http.Server{TLSConfig: &tls.Config{Certificates: s.router.Hosts().Certificates()}}
101-
s.wsUpgrader = &websocket.Upgrader{
102-
Subprotocols: []string{"xmpp"},
103-
CheckOrigin: func(r *http.Request) bool { return r.Header.Get("Sec-WebSocket-Protocol") == "xmpp" },
104-
}
105-
106-
// start listening
107-
ln, err := listenerProvider("tcp", address)
108-
if err != nil {
109-
return err
110-
}
111-
atomic.StoreUint32(&s.listening, 1)
112-
return s.wsSrv.ServeTLS(ln, "", "")
113-
}
114-
115-
func (s *server) websocketUpgrade(w http.ResponseWriter, r *http.Request) {
116-
conn, err := s.wsUpgrader.Upgrade(w, r, nil)
117-
if err != nil {
118-
log.Error(err)
119-
return
120-
}
121-
s.startStream(transport.NewWebSocketTransport(conn, s.cfg.Transport.KeepAlive))
122-
}
123-
12494
func (s *server) shutdown(ctx context.Context) error {
12595
if atomic.CompareAndSwapUint32(&s.listening, 1, 0) {
12696
// stop listening
@@ -129,10 +99,6 @@ func (s *server) shutdown(ctx context.Context) error {
12999
if err := s.ln.Close(); err != nil {
130100
return err
131101
}
132-
case transport.WebSocket:
133-
if err := s.wsSrv.Shutdown(ctx); err != nil {
134-
return err
135-
}
136102
}
137103
// close all connections
138104
c, err := s.closeConnections(ctx)
@@ -144,18 +110,18 @@ func (s *server) shutdown(ctx context.Context) error {
144110
return nil
145111
}
146112

147-
func (s *server) startStream(tr transport.Transport) {
113+
func (s *server) startStream(tr transport.Transport, keepAlive time.Duration) {
148114
cfg := &streamConfig{
149-
transport: tr,
150115
resourceConflict: s.cfg.ResourceConflict,
151116
connectTimeout: s.cfg.ConnectTimeout,
117+
keepAlive: s.cfg.KeepAlive,
152118
timeout: s.cfg.Timeout,
153119
maxStanzaSize: s.cfg.MaxStanzaSize,
154120
sasl: s.cfg.SASL,
155121
compression: s.cfg.Compression,
156122
onDisconnect: s.unregisterStream,
157123
}
158-
stm := newStream(s.nextID(), cfg, s.mods, s.comps, s.router, s.userRep, s.blockListRep)
124+
stm := newStream(s.nextID(), cfg, tr, s.mods, s.comps, s.router, s.userRep, s.blockListRep)
159125
s.registerStream(stm)
160126
}
161127

0 commit comments

Comments
 (0)