-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtcp_handle.go
152 lines (124 loc) · 3.68 KB
/
tcp_handle.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package main
import (
"bufio"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"github.com/sirupsen/logrus"
)
// readRequestHeaderAndBody reads one HTTP/1.x request from conn.
//
// It returns the raw header section exactly as received (request line, header
// lines and the terminating blank line) and, when a positive Content-Length
// header is present, the body bytes. Without a Content-Length the body is nil.
//
// The header section ends at the first empty line ("\r\n", or a bare "\n" from
// lenient clients) or at EOF. Header names are compared case-insensitively, as
// HTTP requires. A Content-Length that is not a non-negative integer is
// reported as an error rather than silently treated as "no body".
//
// NOTE(review): conn is wrapped in a bufio.Reader that is dropped on return, so
// any bytes the client pipelined after this request's body are lost.
// NOTE(review): the body buffer is sized from the client-supplied
// Content-Length with no upper bound; consider capping it for untrusted clients.
func readRequestHeaderAndBody(conn net.Conn) (string, []byte, error) {
	reader := bufio.NewReader(conn)
	var header strings.Builder
	var contentLength int64

	for {
		line, err := reader.ReadString('\n')
		if err != nil && err != io.EOF {
			return "", nil, fmt.Errorf("error reading client request: %w", err)
		}
		header.WriteString(line)

		// Field names are case-insensitive; "content-length: 5" must count too.
		if i := strings.IndexByte(line, ':'); i > 0 && strings.EqualFold(line[:i], "Content-Length") {
			n, perr := strconv.ParseInt(strings.TrimSpace(line[i+1:]), 10, 64)
			if perr != nil || n < 0 {
				return "", nil, fmt.Errorf("invalid Content-Length header %q", strings.TrimSpace(line))
			}
			contentLength = n
		}

		if err == io.EOF || line == "\r\n" || line == "\n" {
			break
		}
	}

	var body []byte
	if contentLength > 0 {
		body = make([]byte, contentLength)
		if _, err := io.ReadFull(reader, body); err != nil {
			return "", nil, fmt.Errorf("error reading request body: %w", err)
		}
	}
	return header.String(), body, nil
}
// 处理CONNECT请求(HTTPS代理)
func handleConnectRequest_https(conn net.Conn, target, reqLine string) {
hostPort := strings.Split(target, ":")
if len(hostPort) != 2 {
logrus.Errorln("Invalid target format")
return
}
host := hostPort[0]
port := hostPort[1]
proxy_upstream := *proxyAddr
upstream, ForwardMethod := getForwardMethodForHost(proxy_upstream, host, port, "https")
// 调用 forward 函数进行请求转发
forward(upstream, ForwardMethod, reqLine, conn, host)
}
// forward completes a CONNECT request by opening a tunnel for the client
// connection conn, using the strategy named by forward_method:
//
//   - "proxy":  dial the upstream proxy at upstreamHost, send it the client's
//     CONNECT request text reqLine, relay the proxy's reply (as returned by
//     readRequestHeader) to the client, then splice the two connections.
//   - "direct": dial upstreamHost itself, reply
//     "HTTP/1.1 200 Connection Established" to the client, then splice.
//
// host is passed through to forward_io_copy, which uses it as a metric label.
// forward cannot report errors to its caller, so every failure path logs the
// error and closes every connection it holds, including conn. On success,
// forward_io_copy closes both connections when the tunnel ends.
func forward(upstreamHost, forward_method, reqLine string, conn net.Conn, host string) {
	// badGateway tells the client the tunnel could not be opened, then hangs up.
	badGateway := func() {
		// Best effort: the client may already be gone, and conn is closed next anyway.
		_, _ = conn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
		conn.Close()
	}

	switch forward_method {
	case "proxy":
		upstreamConn, err := net.Dial("tcp", upstreamHost)
		if err != nil {
			logrus.Errorln("Error connecting to target:", err)
			badGateway()
			return
		}
		// Replay the client's CONNECT so the upstream proxy opens the tunnel.
		if _, err = upstreamConn.Write([]byte(reqLine)); err != nil {
			logrus.Errorln("Error forwarding CONNECT to upstream:", err)
			upstreamConn.Close()
			badGateway()
			return
		}
		upstreamResp, err := readRequestHeader(upstreamConn)
		if err != nil {
			logrus.Errorln("readRequestHeader(upstreamConn) error ", err)
			upstreamConn.Close()
			badGateway()
			return
		}
		// The upstream reply is passed through verbatim, without checking its status.
		if _, err = conn.Write([]byte(upstreamResp)); err != nil {
			logrus.Errorln("Error forwarding response to client:", err)
			upstreamConn.Close()
			conn.Close()
			return
		}
		forward_io_copy(conn, upstreamConn, upstreamHost, forward_method, host)

	case "direct":
		targetConn, err := net.Dial("tcp", upstreamHost)
		if err != nil {
			logrus.Errorln("Error connecting to target:", err)
			badGateway()
			return
		}
		if _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
			logrus.Errorln("Error writing to client:", err)
			// Previously leaked: forward_io_copy is never reached on this path.
			targetConn.Close()
			conn.Close()
			return
		}
		forward_io_copy(conn, targetConn, upstreamHost, forward_method, host)

	default:
		logrus.Warnln("Unsupported forward method:", forward_method)
		conn.Close()
	}
}
func forward_io_copy(conn, targetConn net.Conn, upstreamHost, forward_method, host string) {
// 使用 channel 和 WaitGroup 来管理双向转发
errCh := make(chan error, 2)
wg := &sync.WaitGroup{}
wg.Add(2)
// 转发 conn -> targetConn
go func() {
defer wg.Done()
written, err := io.Copy(targetConn, conn)
if err != nil {
errCh <- fmt.Errorf("error copying data to upstream: %w", err)
}
forwardedBytes.WithLabelValues("https", host, forward_method).Add(float64(written))
}()
// 转发 targetConn -> conn
go func() {
defer wg.Done()
written, err := io.Copy(conn, targetConn)
if err != nil {
errCh <- fmt.Errorf("error copying data to client: %w", err)
}
forwardedBytes.WithLabelValues("https", host, forward_method).Add(float64(written))
}()
// 等待转发完成
wg.Wait()
targetConn.Close()
conn.Close()
close(errCh)
}