Skip to content

Commit 1b7f8e9

Browse files
committed
Add HTTP proxy support for tunnel connections
1 parent 1cedefa commit 1b7f8e9

File tree

5 files changed

+198
-22
lines changed

5 files changed

+198
-22
lines changed

ingress/origin_dialer.go

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/rs/zerolog"
12+
"golang.org/x/net/proxy"
1213
)
1314

1415
// OriginTCPDialer provides a TCP dial operation to a requested address.
@@ -115,20 +116,49 @@ func (d *OriginDialerService) DialUDP(addr netip.AddrPort) (net.Conn, error) {
115116
}
116117

117118
type Dialer struct {
118-
Dialer net.Dialer
119+
Dialer proxy.Dialer
119120
}
120121

121122
func NewDialer(config WarpRoutingConfig) *Dialer {
123+
// Create proxy-aware dialer for warp routing
124+
proxyDialer := createProxyDialer(config.ConnectTimeout.Duration, config.TCPKeepAlive.Duration, nil)
122125
return &Dialer{
123-
Dialer: net.Dialer{
124-
Timeout: config.ConnectTimeout.Duration,
125-
KeepAlive: config.TCPKeepAlive.Duration,
126-
},
126+
Dialer: proxyDialer,
127127
}
128128
}
129129

130+
// createProxyDialer creates a proxy.Dialer that respects proxy environment variables
131+
func createProxyDialer(timeout, keepAlive time.Duration, logger *zerolog.Logger) proxy.Dialer {
132+
baseDialer := &net.Dialer{
133+
Timeout: timeout,
134+
KeepAlive: keepAlive,
135+
}
136+
137+
// Check for SOCKS proxy first using golang.org/x/net/proxy
138+
if proxyDialer := proxy.FromEnvironmentUsing(baseDialer); proxyDialer != baseDialer {
139+
if logger != nil {
140+
logger.Debug().Msg("proxy: using SOCKS proxy from environment")
141+
}
142+
return proxyDialer
143+
}
144+
145+
// Fall back to direct connection if no proxy configured
146+
if logger != nil {
147+
logger.Debug().Msg("proxy: no SOCKS proxy configured, using direct connection")
148+
}
149+
return baseDialer
150+
}
151+
130152
func (d *Dialer) DialTCP(ctx context.Context, dest netip.AddrPort) (net.Conn, error) {
131-
conn, err := d.Dialer.DialContext(ctx, "tcp", dest.String())
153+
var conn net.Conn
154+
var err error
155+
156+
if contextDialer, ok := d.Dialer.(proxy.ContextDialer); ok {
157+
conn, err = contextDialer.DialContext(ctx, "tcp", dest.String())
158+
} else {
159+
conn, err = d.Dialer.Dial("tcp", dest.String())
160+
}
161+
132162
if err != nil {
133163
return nil, fmt.Errorf("unable to dial tcp to origin %s: %w", dest, err)
134164
}

ingress/origin_proxy.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99

1010
"github.com/rs/zerolog"
11+
"golang.org/x/net/proxy"
1112
)
1213

1314
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@@ -86,7 +87,15 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
8687
}
8788

8889
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
89-
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
90+
var conn net.Conn
91+
var err error
92+
93+
if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok {
94+
conn, err = contextDialer.DialContext(ctx, "tcp", dest)
95+
} else {
96+
conn, err = o.dialer.Dial("tcp", dest)
97+
}
98+
9099
if err != nil {
91100
return nil, err
92101
}
@@ -105,7 +114,13 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string,
105114
dest = o.dest
106115
}
107116

108-
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
117+
var conn net.Conn
118+
if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok {
119+
conn, err = contextDialer.DialContext(ctx, "tcp", dest)
120+
} else {
121+
conn, err = o.dialer.Dial("tcp", dest)
122+
}
123+
109124
if err != nil {
110125
return nil, err
111126
}

ingress/origin_proxy_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http/httptest"
1010
"net/url"
1111
"testing"
12+
"time"
1213

1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
@@ -24,7 +25,10 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
2425
listenerClosed := make(chan struct{})
2526
tcpListenRoutine(originListener, listenerClosed)
2627

27-
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
28+
rawTCPService := &rawTCPService{
29+
name: ServiceWarpRouting,
30+
dialer: &net.Dialer{Timeout: 30 * time.Second},
31+
}
2832

2933
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
3034
require.NoError(t, err)

ingress/origin_service.go

Lines changed: 134 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ingress
22

33
import (
4+
"bufio"
45
"context"
56
"crypto/tls"
67
"encoding/json"
@@ -9,11 +10,13 @@ import (
910
"net"
1011
"net/http"
1112
"net/url"
13+
"os"
1214
"strconv"
1315
"time"
1416

1517
"github.com/pkg/errors"
1618
"github.com/rs/zerolog"
19+
"golang.org/x/net/proxy"
1720

1821
"github.com/cloudflare/cloudflared/hello"
1922
"github.com/cloudflare/cloudflared/ipaccess"
@@ -97,7 +100,7 @@ func (o httpService) MarshalJSON() ([]byte, error) {
97100
// It's used by warp routing
98101
type rawTCPService struct {
99102
name string
100-
dialer net.Dialer
103+
dialer proxy.Dialer
101104
writeTimeout time.Duration
102105
logger *zerolog.Logger
103106
}
@@ -114,14 +117,136 @@ func (o rawTCPService) MarshalJSON() ([]byte, error) {
114117
return json.Marshal(o.String())
115118
}
116119

120+
// proxyAwareDialer wraps net.Dialer with proxy support for both HTTP CONNECT and SOCKS
121+
type proxyAwareDialer struct {
122+
baseDialer *net.Dialer
123+
logger *zerolog.Logger
124+
}
125+
126+
// newProxyAwareDialer creates a dialer that supports proxy settings from environment
127+
func newProxyAwareDialer(timeout, keepAlive time.Duration, logger *zerolog.Logger) proxy.Dialer {
128+
baseDialer := &net.Dialer{
129+
Timeout: timeout,
130+
KeepAlive: keepAlive,
131+
}
132+
133+
httpProxy := getEnvProxy("HTTP_PROXY", "http_proxy")
134+
httpsProxy := getEnvProxy("HTTPS_PROXY", "https_proxy")
135+
136+
if httpProxy == "" && httpsProxy == "" {
137+
if logger != nil {
138+
logger.Debug().Msg("proxy: no proxy configured, using direct connection")
139+
}
140+
return baseDialer
141+
}
142+
143+
if logger != nil {
144+
logger.Debug().Str("HTTP_PROXY", httpProxy).Str("HTTPS_PROXY", httpsProxy).Msg("proxy: detected proxy configuration")
145+
}
146+
return &proxyAwareDialer{
147+
baseDialer: baseDialer,
148+
logger: logger,
149+
}
150+
}
151+
152+
func getEnvProxy(upper, lower string) string {
153+
if v := os.Getenv(upper); v != "" {
154+
return v
155+
}
156+
return os.Getenv(lower)
157+
}
158+
159+
func (p *proxyAwareDialer) Dial(network, addr string) (net.Conn, error) {
160+
return p.DialContext(context.Background(), network, addr)
161+
}
162+
163+
func (p *proxyAwareDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
164+
if network != "tcp" {
165+
return p.baseDialer.DialContext(ctx, network, addr)
166+
}
167+
168+
req := &http.Request{URL: &url.URL{Scheme: "http", Host: addr}}
169+
proxyURL, err := http.ProxyFromEnvironment(req)
170+
if err != nil || proxyURL == nil {
171+
if p.logger != nil {
172+
p.logger.Debug().Str("addr", addr).Msg("proxy: direct connection to")
173+
}
174+
return p.baseDialer.DialContext(ctx, network, addr)
175+
}
176+
177+
if p.logger != nil {
178+
p.logger.Debug().Str("proxy_url", proxyURL.String()).Str("addr", addr).Msg("proxy: using proxy")
179+
}
180+
181+
switch proxyURL.Scheme {
182+
case "socks4", "socks5":
183+
return p.dialSOCKS(ctx, proxyURL, network, addr)
184+
case "http", "https":
185+
return p.dialHTTPConnect(ctx, proxyURL, addr)
186+
default:
187+
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
188+
}
189+
}
190+
191+
func (p *proxyAwareDialer) dialSOCKS(ctx context.Context, proxyURL *url.URL, network, addr string) (net.Conn, error) {
192+
socksDialer, err := proxy.FromURL(proxyURL, p.baseDialer)
193+
if err != nil {
194+
return nil, fmt.Errorf("SOCKS proxy error: %w", err)
195+
}
196+
197+
if contextDialer, ok := socksDialer.(proxy.ContextDialer); ok {
198+
return contextDialer.DialContext(ctx, network, addr)
199+
}
200+
return socksDialer.Dial(network, addr)
201+
}
202+
203+
func (p *proxyAwareDialer) dialHTTPConnect(ctx context.Context, proxyURL *url.URL, addr string) (net.Conn, error) {
204+
proxyAddr := proxyURL.Host
205+
if proxyURL.Port() == "" {
206+
if proxyURL.Scheme == "https" {
207+
proxyAddr = net.JoinHostPort(proxyURL.Hostname(), "443")
208+
} else {
209+
proxyAddr = net.JoinHostPort(proxyURL.Hostname(), "80")
210+
}
211+
}
212+
213+
conn, err := p.baseDialer.DialContext(ctx, "tcp", proxyAddr)
214+
if err != nil {
215+
return nil, fmt.Errorf("proxy connection failed: %w", err)
216+
}
217+
218+
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", addr, addr)
219+
if _, err := conn.Write([]byte(connectReq)); err != nil {
220+
conn.Close()
221+
return nil, fmt.Errorf("CONNECT request failed: %w", err)
222+
}
223+
224+
br := bufio.NewReader(conn)
225+
resp, err := http.ReadResponse(br, &http.Request{Method: "CONNECT"})
226+
if err != nil {
227+
conn.Close()
228+
return nil, fmt.Errorf("CONNECT response failed: %w", err)
229+
}
230+
resp.Body.Close()
231+
232+
if resp.StatusCode != 200 {
233+
conn.Close()
234+
return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
235+
}
236+
237+
if p.logger != nil {
238+
p.logger.Debug().Str("addr", addr).Msg("proxy: HTTP CONNECT successful")
239+
}
240+
return conn, nil
241+
}
242+
117243
// tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as
118-
// cloudflared access commands.
119244
type tcpOverWSService struct {
120245
scheme string
121246
dest string
122247
isBastion bool
123248
streamHandler streamHandlerFunc
124-
dialer net.Dialer
249+
dialer proxy.Dialer
125250
}
126251

127252
type socksProxyOverWSService struct {
@@ -142,12 +267,14 @@ func newTCPOverWSService(url *url.URL) *tcpOverWSService {
142267
return &tcpOverWSService{
143268
scheme: url.Scheme,
144269
dest: url.Host,
270+
dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil),
145271
}
146272
}
147273

148274
func newBastionService() *tcpOverWSService {
149275
return &tcpOverWSService{
150276
isBastion: true,
277+
dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil),
151278
}
152279
}
153280

@@ -187,8 +314,8 @@ func (o *tcpOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg Ori
187314
} else {
188315
o.streamHandler = DefaultStreamHandler
189316
}
190-
o.dialer.Timeout = cfg.ConnectTimeout.Duration
191-
o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration
317+
// Recreate dialer with new timeout and keepalive settings
318+
o.dialer = newProxyAwareDialer(cfg.ConnectTimeout.Duration, cfg.TCPKeepAlive.Duration, log)
192319
return nil
193320
}
194321

@@ -291,11 +418,8 @@ type WarpRoutingService struct {
291418

292419
func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
293420
svc := &rawTCPService{
294-
name: ServiceWarpRouting,
295-
dialer: net.Dialer{
296-
Timeout: config.ConnectTimeout.Duration,
297-
KeepAlive: config.TCPKeepAlive.Duration,
298-
},
421+
name: ServiceWarpRouting,
422+
dialer: newProxyAwareDialer(config.ConnectTimeout.Duration, config.TCPKeepAlive.Duration, nil),
299423
writeTimeout: writeTimeout,
300424
}
301425

ingress/origins/dns.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,18 @@ func (r *resolver) peekDial(ctx context.Context, network, address string) (net.C
205205

206206
// NewDNSDialer creates a custom dialer for the DNS resolver service to utilize.
207207
func NewDNSDialer() *ingress.Dialer {
208-
return &ingress.Dialer{
209-
Dialer: net.Dialer{
208+
// For DNS, use direct connection to avoid circular dependencies
209+
netDialer := &net.Dialer{
210210
// We want short timeouts for the DNS requests
211211
Timeout: 5 * time.Second,
212212
// We do not want keep alive since the edge will not reuse TCP connections per request
213213
KeepAlive: -1,
214214
KeepAliveConfig: net.KeepAliveConfig{
215215
Enable: false,
216216
},
217-
},
217+
}
218+
219+
return &ingress.Dialer{
220+
Dialer: netDialer,
218221
}
219222
}

0 commit comments

Comments
 (0)