Skip to content

Commit e2915c6

Browse files
committed
let parent process clean up
1 parent 52e236f commit e2915c6

File tree

4 files changed

+48
-59
lines changed

4 files changed

+48
-59
lines changed

app/server.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ import (
77
"os"
88
"os/signal"
99
"runtime/debug"
10-
"sync"
1110
"syscall"
1211

1312
"github.com/mushorg/glutton"
13+
1414
"github.com/spf13/pflag"
1515
"github.com/spf13/viper"
1616
)
@@ -32,7 +32,7 @@ func main() {
3232
\_____|_|\__,_|\__|\__\___/|_| |_|
3333
3434
`)
35-
fmt.Printf("%s %s\n", VERSION, BUILDDATE)
35+
fmt.Printf("%s %s\n\n", VERSION, BUILDDATE)
3636

3737
pflag.StringP("interface", "i", "eth0", "Bind to this interface")
3838
pflag.IntP("ssh", "s", 0, "Override SSH port")
@@ -53,39 +53,37 @@ func main() {
5353
return
5454
}
5555

56-
gtn, err := glutton.New(context.Background())
56+
g, err := glutton.New(context.Background())
5757
if err != nil {
5858
log.Fatal(err)
5959
}
6060

61-
if err := gtn.Init(); err != nil {
61+
if err := g.Init(); err != nil {
6262
log.Fatal("Failed to initialize Glutton:", err)
6363
}
6464

65-
exitMtx := sync.RWMutex{}
6665
exit := func() {
6766
// See if there was a panic...
6867
if r := recover(); r != nil {
6968
fmt.Fprintln(os.Stderr, r)
7069
fmt.Println("stacktrace from panic: \n" + string(debug.Stack()))
7170
}
72-
exitMtx.Lock()
73-
gtn.Shutdown()
74-
exitMtx.Unlock()
71+
g.Shutdown()
7572
}
76-
defer exit()
7773

7874
// capture and handle signals
7975
sig := make(chan os.Signal, 1)
8076
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
8177
go func() {
8278
<-sig
79+
fmt.Print("\r")
8380
exit()
84-
fmt.Println("\nleaving...")
81+
fmt.Println()
8582
os.Exit(0)
8683
}()
8784

88-
if err := gtn.Start(); err != nil {
85+
if err := g.Start(); err != nil {
86+
exit()
8987
log.Fatal("Failed to start Glutton server:", err)
9088
}
9189
}

glutton.go

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -148,55 +148,71 @@ func (g *Glutton) Init() error {
148148
}
149149

150150
func (g *Glutton) udpListen(wg *sync.WaitGroup) {
151-
defer wg.Done()
151+
defer func() {
152+
wg.Done()
153+
}()
152154
buffer := make([]byte, 1024)
153155
for {
154-
n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(g.Server.udpListener, buffer)
156+
select {
157+
case <-g.ctx.Done():
158+
if err := g.Server.udpConn.Close(); err != nil {
159+
g.Logger.Error("Failed to close UDP listener", producer.ErrAttr(err))
160+
}
161+
return
162+
default:
163+
}
164+
g.Server.udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
165+
n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(g.Server.udpConn, buffer)
155166
if err != nil {
156-
g.Logger.Error("failed to read UDP packet", producer.ErrAttr(err))
167+
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
168+
continue
169+
}
170+
g.Logger.Error("Failed to read UDP packet", producer.ErrAttr(err))
157171
}
158172

159173
rule, err := g.applyRules("udp", srcAddr, dstAddr)
160174
if err != nil {
161-
g.Logger.Error("failed to apply rules", producer.ErrAttr(err))
175+
g.Logger.Error("Failed to apply rules", producer.ErrAttr(err))
162176
}
163177
if rule == nil {
164178
rule = &rules.Rule{Target: "udp"}
165179
}
166180
md, err := g.connTable.Register(srcAddr.IP.String(), strconv.Itoa(int(srcAddr.AddrPort().Port())), dstAddr.AddrPort().Port(), rule)
167181
if err != nil {
168-
g.Logger.Error("failed to register UDP packet", producer.ErrAttr(err))
182+
g.Logger.Error("Failed to register UDP packet", producer.ErrAttr(err))
169183
}
170184

171185
if hfunc, ok := g.udpProtocolHandlers[rule.Target]; ok {
172186
data := buffer[:n]
173187
go func() {
174188
if err := hfunc(g.ctx, srcAddr, dstAddr, data, md); err != nil {
175-
g.Logger.Error("failed to handle UDP payload", producer.ErrAttr(err))
189+
g.Logger.Error("Failed to handle UDP payload", producer.ErrAttr(err))
176190
}
177191
}()
178192
}
179193
}
180194
}
181195

182-
func (g *Glutton) tcpListen(wg *sync.WaitGroup) {
183-
defer wg.Done()
196+
func (g *Glutton) tcpListen() {
184197
for {
185198
select {
186199
case <-g.ctx.Done():
200+
if err := g.Server.tcpListener.Close(); err != nil {
201+
g.Logger.Error("Failed to close TCP listener", producer.ErrAttr(err))
202+
}
187203
return
188204
default:
189205
}
190206

191207
conn, err := g.Server.tcpListener.Accept()
192208
if err != nil {
193-
g.Logger.Error("failed to accept connection", producer.ErrAttr(err))
209+
g.Logger.Error("Failed to accept connection", producer.ErrAttr(err))
194210
continue
195211
}
196212

197213
rule, err := g.applyRulesOnConn(conn)
198214
if err != nil {
199-
g.Logger.Error("failed to apply rules", producer.ErrAttr(err))
215+
g.Logger.Error("Failed to apply rules", producer.ErrAttr(err))
200216
continue
201217
}
202218
if rule == nil {
@@ -205,21 +221,21 @@ func (g *Glutton) tcpListen(wg *sync.WaitGroup) {
205221

206222
md, err := g.connTable.RegisterConn(conn, rule)
207223
if err != nil {
208-
g.Logger.Error("failed to register connection", producer.ErrAttr(err))
224+
g.Logger.Error("Failed to register connection", producer.ErrAttr(err))
209225
continue
210226
}
211227

212228
g.Logger.Debug("new connection", slog.String("addr", conn.LocalAddr().String()), slog.String("handler", rule.Target))
213229

214230
g.ctx = context.WithValue(g.ctx, ctxTimeout("timeout"), int64(viper.GetInt("conn_timeout")))
215231
if err := g.UpdateConnectionTimeout(g.ctx, conn); err != nil {
216-
g.Logger.Error("failed to set connection timeout", producer.ErrAttr(err))
232+
g.Logger.Error("Failed to set connection timeout", producer.ErrAttr(err))
217233
}
218234

219235
if hfunc, ok := g.tcpProtocolHandlers[rule.Target]; ok {
220236
go func() {
221237
if err := hfunc(g.ctx, conn, md); err != nil {
222-
g.Logger.Error("failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target))
238+
g.Logger.Error("Failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target))
223239
}
224240
}()
225241
}
@@ -228,13 +244,7 @@ func (g *Glutton) tcpListen(wg *sync.WaitGroup) {
228244

229245
// Start the listener, this blocks for new connections
230246
func (g *Glutton) Start() error {
231-
quit := make(chan struct{}) // stop monitor on shutdown
232-
defer func() {
233-
quit <- struct{}{}
234-
g.Shutdown()
235-
}()
236-
237-
g.startMonitor(quit)
247+
g.startMonitor()
238248

239249
sshPort := viper.GetUint32("ports.ssh")
240250
if err := setTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), sshPort); err != nil {
@@ -249,9 +259,7 @@ func (g *Glutton) Start() error {
249259

250260
wg.Add(1)
251261
go g.udpListen(wg)
252-
253-
wg.Add(1)
254-
go g.tcpListen(wg)
262+
go g.tcpListen()
255263

256264
wg.Wait()
257265

@@ -350,18 +358,13 @@ func (g *Glutton) ProduceUDP(handler string, srcAddr, dstAddr *net.UDPAddr, md c
350358
func (g *Glutton) Shutdown() {
351359
g.cancel() // close all connection
352360

353-
g.Logger.Info("Shutting down listeners")
354-
if err := g.Server.Shutdown(); err != nil {
355-
g.Logger.Error("failed to shutdown server", producer.ErrAttr(err))
356-
}
357-
358-
g.Logger.Info("FLushing TCP iptables")
361+
g.Logger.Info("Flushing TCP iptables")
359362
if err := flushTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), uint32(viper.GetInt("ports.ssh"))); err != nil {
360-
g.Logger.Error("failed to drop tcp iptables", producer.ErrAttr(err))
363+
g.Logger.Error("Failed to drop tcp iptables", producer.ErrAttr(err))
361364
}
362-
g.Logger.Info("FLushing UDP iptables")
365+
g.Logger.Info("Flushing UDP iptables")
363366
if err := flushTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "udp", uint32(g.Server.udpPort), uint32(viper.GetInt("ports.ssh"))); err != nil {
364-
g.Logger.Error("failed to drop udp iptables", producer.ErrAttr(err))
367+
g.Logger.Error("Failed to drop udp iptables", producer.ErrAttr(err))
365368
}
366369

367370
g.Logger.Info("All done")

server.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010

1111
type Server struct {
1212
tcpListener net.Listener
13-
udpListener *net.UDPConn
13+
udpConn *net.UDPConn
1414
tcpPort uint
1515
udpPort uint
1616
}
@@ -31,26 +31,16 @@ func (s *Server) Start() error {
3131
if s.tcpListener, err = tproxy.ListenTCP("tcp4", tcpAddr); err != nil {
3232
return err
3333
}
34+
3435
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", s.udpPort))
3536
if err != nil {
3637
return err
3738
}
38-
if s.udpListener, err = tproxy.ListenUDP("udp4", udpAddr); err != nil {
39+
if s.udpConn, err = tproxy.ListenUDP("udp4", udpAddr); err != nil {
3940
return err
4041
}
41-
if s.udpListener == nil {
42+
if s.udpConn == nil {
4243
return errors.New("nil udp listener")
4344
}
4445
return nil
4546
}
46-
47-
func (s *Server) Shutdown() error {
48-
var err error
49-
if s.tcpListener != nil {
50-
err = s.tcpListener.Close()
51-
}
52-
if s.udpListener != nil {
53-
err = s.udpListener.Close()
54-
}
55-
return err
56-
}

server_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,4 @@ import (
99
func TestServer(t *testing.T) {
1010
server := NewServer(1234, 1235)
1111
require.NotNil(t, server)
12-
err := server.Shutdown()
13-
require.NoError(t, err)
1412
}

0 commit comments

Comments
 (0)