/
loadbalancer.go
209 lines (181 loc) · 5.95 KB
/
loadbalancer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
package loadbalancer
import (
"context"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"sync"
"github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus"
"inet.af/tcpproxy"
)
// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
// address identifies the remote server.
// NOTE(review): presumably matches this server's key in LoadBalancer.servers - confirm.
address string
// healthCheck is consulted by LoadBalancer.dialContext before dialing;
// the server is skipped for that attempt when it returns false.
healthCheck func() bool
// connections is the set of tracked connections to this server (guarded by mutex).
connections map[net.Conn]struct{}
}
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
type serverConn struct {
// server is the owner whose connection map this connection belongs to.
server *server
// Conn is the wrapped connection; embedding it promotes its methods onto serverConn.
net.Conn
}
// LoadBalancer holds data for a local listener which forwards connections to a
// pool of remote servers. It is not a proper load-balancer in that it does not
// actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails.
type LoadBalancer struct {
// This mutex protects access to servers map and randomServers list.
// All direct access to the servers map/list should be protected by it.
mutex sync.RWMutex
// proxy is the TCP proxy created and started by New; it accepts on the
// local listener and dials out through dialContext.
proxy *tcpproxy.Proxy
// serviceName names this load balancer; New uses it for log messages,
// the proxy route, and the config file name.
serviceName string
// configFile is the path <dataDir>/etc/<serviceName>.json, as built by New.
configFile string
// localAddress is the address the local listener actually bound to.
localAddress string
// localServerURL is the second value returned by parseURL in New and is
// what LoadBalancerServerURL returns.
localServerURL string
// defaultServerAddress is the address parsed from the initial server URL;
// New leaves it empty when that URL equals localServerURL.
defaultServerAddress string
// ServerURL is the server URL that was passed to New.
ServerURL string
// ServerAddresses is the server address list reported in log messages.
// NOTE(review): presumably maintained by setServers - confirm.
ServerAddresses []string
// randomServers is the list whose length bounds the failover index in dialContext.
// NOTE(review): presumably a shuffled copy of the server addresses - confirm.
randomServers []string
// servers maps a server address to its connection-tracking state.
servers map[string]*server
// currentServerAddress is the server dialContext tries first.
currentServerAddress string
// nextServerIndex is compared by dialContext against its starting value to
// detect that every server has been tried.
nextServerIndex int
// NOTE(review): Listener is not assigned anywhere in the visible code - confirm where it is set.
Listener net.Listener
}
// RandomPort is the port value 0; when New binds its listener to port 0 the
// OS assigns the port, and New then reads the bound address back from the listener.
const RandomPort = 0
// Service names for the individual load balancers, each prefixed with version.Program.
// NOTE(review): presumably passed as the serviceName argument of New and
// ResetLoadBalancer - confirm against callers.
var (
SupervisorServiceName = version.Program + "-agent-load-balancer"
APIServerServiceName = version.Program + "-api-server-agent-load-balancer"
ETCDServerServiceName = version.Program + "-etcd-server-load-balancer"
)
// New constructs a new LoadBalancer instance. The default server URL, and
// currently active servers, are stored in a file within the dataDir.
//
// The listener binds to the loopback address ([::1] when isIPv6 is set,
// 127.0.0.1 otherwise) on lbServerPort. If any step fails, a warning is logged
// and the listener, if it was opened, is closed before the error is returned.
// On success the proxy is running and health checks have been started in a
// goroutine that is handed ctx.
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
	loopbackHost := "127.0.0.1"
	if isIPv6 {
		loopbackHost = "[::1]"
	}
	localAddress := fmt.Sprintf("%s:%d", loopbackHost, lbServerPort)

	listenConfig := net.ListenConfig{Control: reusePort}
	listener, err := listenConfig.Listen(ctx, "tcp", localAddress)
	// Registered before the error check so that every failing return path,
	// including a failed Listen, goes through the same cleanup.
	defer func() {
		if _err == nil {
			return
		}
		logrus.Warnf("Error starting load balancer: %s", _err)
		if listener != nil {
			listener.Close()
		}
	}()
	if err != nil {
		return nil, err
	}

	// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
	localAddress = listener.Addr().String()

	defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
	if err != nil {
		return nil, err
	}
	if localServerURL == serverURL {
		logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
		defaultServerAddress = ""
	}

	stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
	lb := &LoadBalancer{
		ServerURL:            serverURL,
		serviceName:          serviceName,
		configFile:           stateFile,
		localAddress:         localAddress,
		localServerURL:       localServerURL,
		defaultServerAddress: defaultServerAddress,
		servers:              make(map[string]*server),
	}
	lb.setServers([]string{lb.defaultServerAddress})

	// The proxy reuses the listener opened above instead of opening its own.
	reuseListener := func(string, string) (net.Listener, error) {
		return listener, nil
	}
	lb.proxy = &tcpproxy.Proxy{ListenFunc: reuseListener}

	route := &tcpproxy.DialProxy{
		Addr:        serviceName,
		DialContext: lb.dialContext,
		OnDialError: onDialError,
	}
	lb.proxy.AddRoute(serviceName, route)

	if err := lb.updateConfig(); err != nil {
		return nil, err
	}
	if err := lb.proxy.Start(); err != nil {
		return nil, err
	}

	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress)
	go lb.runHealthChecks(ctx)
	return lb, nil
}
// Update hands serverAddresses to setServers and, when setServers returns
// true, logs the resulting server list and writes the config. A failure to
// write the config is logged as a warning and not returned. Calling Update
// on a nil *LoadBalancer is a no-op.
func (lb *LoadBalancer) Update(serverAddresses []string) {
	// Short-circuit keeps setServers from being called on a nil receiver.
	if lb == nil || !lb.setServers(serverAddresses) {
		return
	}

	logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.ServerAddresses, lb.defaultServerAddress)

	err := lb.writeConfig()
	if err == nil {
		return
	}
	logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
}
// LoadBalancerServerURL returns the local server URL of the load balancer,
// or the empty string when called on a nil *LoadBalancer.
func (lb *LoadBalancer) LoadBalancerServerURL() string {
	if lb != nil {
		return lb.localServerURL
	}
	return ""
}
// dialContext is the dial function installed on the proxy route. It tries the
// currently selected server first and, when that server is missing, fails its
// health check, or cannot be dialed, asks nextServer for another one and
// tries again. The address argument supplied by the proxy is ignored.
//
// It returns the first connection that dials successfully, the error from
// nextServer, the context's error once ctx is done, or "all servers failed"
// when the failover index is back where this call started.
//
// The load balancer's read lock is held for the whole call.
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
	lb.mutex.RLock()
	defer lb.mutex.RUnlock()

	// Remember where the rotation stood on entry so a full lap can be detected.
	firstIndex := lb.nextServerIndex

	for {
		candidate := lb.currentServerAddress
		backend := lb.servers[candidate]

		switch {
		case backend == nil || candidate == "":
			logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, candidate)
		case backend.healthCheck():
			conn, dialErr := backend.dialContext(ctx, network, candidate)
			if dialErr == nil {
				return conn, nil
			}
			logrus.Debugf("Dial error from load balancer %s: %s", lb.serviceName, dialErr)
		}

		replacement, err := lb.nextServer(candidate)
		if err != nil {
			return nil, err
		}
		if replacement != candidate {
			logrus.Debugf("Failed over to new server for load balancer %s: %s", lb.serviceName, replacement)
		}

		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}

		// Clamp the remembered index to the current list length so the lap
		// check below can still match if the starting index exceeds it.
		if limit := len(lb.randomServers); firstIndex > limit {
			firstIndex = limit
		}
		if firstIndex == lb.nextServerIndex {
			return nil, errors.New("all servers failed")
		}
	}
}
// onDialError is installed as the proxy route's OnDialError callback. It logs
// the dial failure together with the remote address of the incoming
// connection, then closes that connection.
func onDialError(src net.Conn, dstDialErr error) {
	remote := src.RemoteAddr()
	logrus.Debugf("Incoming conn %s, error dialing load balancer servers: %v", remote, dstDialErr)
	src.Close()
}
// ResetLoadBalancer will delete the local state file for the load balancer on disk.
//
// The state file is <dataDir>/etc/<serviceName>.json. Removal is best-effort:
// a state file that is already absent is treated as success and produces no
// log output, any other removal error is logged as a warning, and the
// function always returns nil.
func ResetLoadBalancer(dataDir, serviceName string) error {
	stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
	// A missing file means there is nothing to reset; only unexpected
	// failures (permissions, I/O errors, ...) are worth a warning.
	if err := os.Remove(stateFile); err != nil && !errors.Is(err, os.ErrNotExist) {
		logrus.Warn(err)
	}
	return nil
}