summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/failover/failover.go21
-rw-r--r--internal/failover/monitor.go106
-rw-r--r--internal/failover/ping.go51
-rw-r--r--internal/log/log.go1
-rw-r--r--internal/server/base.go2
-rw-r--r--internal/server/profile.go4
-rw-r--r--internal/server/server.go8
7 files changed, 186 insertions, 7 deletions
diff --git a/internal/failover/failover.go b/internal/failover/failover.go
new file mode 100644
index 0000000..f239eeb
--- /dev/null
+++ b/internal/failover/failover.go
@@ -0,0 +1,21 @@
+package failover
+
+import "time"
+
+const (
+ // Send a ping every 2 seconds to the gateway
+ pInterval time.Duration = 2 * time.Second
+
+ // pAlive is how many pings we need to have sent to check if the connection is alive
+ pAlive int = 3
+
+ // pDropped is how many pings we need to have sent to check if the connection is dropped
+ pDropped int = 5
+)
+
+// New creates a failover monitor for the gateway and the rx bytes function reader
+// This is a simple wrapper over `NewDroppedMonitor` to create one with the default settings
+// If this function returns True, the connection is dropped. False means it has exited and we don't know for sure if it's dropped or not
+func New(readRxBytes func() (int64, error)) (*DroppedConMon, error) {
+ return NewDroppedMonitor(pInterval, pAlive, pDropped, readRxBytes)
+}
diff --git a/internal/failover/monitor.go b/internal/failover/monitor.go
new file mode 100644
index 0000000..d14fb9e
--- /dev/null
+++ b/internal/failover/monitor.go
@@ -0,0 +1,106 @@
+package failover
+
+import (
+ "context"
+ "time"
+
+ "github.com/go-errors/errors"
+)
+
+// The DroppedConMon is a connection monitor that checks for an increase in rx bytes in certain intervals
+type DroppedConMon struct {
+ // pInterval means how the interval in which to send pings
+ pInterval time.Duration
+ // pAlive means how many pings need to be send before checking if the connection is alive
+ pAlive int
+ // pDropped means how many pings need to be send before checking if the connection is dropped
+ pDropped int
+ // The function that reads Rx bytes
+ // If this function returns an error, the monitor exits
+ readRxBytes func() (int64, error)
+ // The cancel context
+ // This is used to cancel the dropped connection monitor
+ cancel context.CancelFunc
+}
+
+func NewDroppedMonitor(pingInterval time.Duration, pAlive int, pDropped int, readRxBytes func() (int64, error)) (*DroppedConMon, error) {
+ if pAlive >= pDropped {
+ return nil, errors.New("pAlive must be smaller than pDropped")
+ }
+ return &DroppedConMon{pInterval: pingInterval, pAlive: pAlive, pDropped: pDropped, readRxBytes: readRxBytes}, nil
+}
+
+// Dropped checks whether or not the connection is 'dropped'
+// In other words, it checks if rx bytes has increased
+func (m *DroppedConMon) dropped(startBytes int64) (bool, error) {
+ b, err := m.readRxBytes()
+ if err != nil {
+ return false, err
+ }
+ return b <= startBytes, nil
+}
+
+// Start starts ticking every ping interval and check if the connection is dropped or alive
+// This does not check Rx bytes every tick, but rather when pAlive or pDropped is reached
+// It returns an error if there was an invalid input or a ping was failed to be sent
+func (m *DroppedConMon) Start(gateway string, mtuSize int) (bool, error) {
+ if mtuSize <= 0 {
+ return false, errors.New("invalid mtu size given")
+ }
+
+ // Create a context and save the cancel function
+ ctx, cancel := context.WithCancel(context.Background())
+ m.cancel = cancel
+ defer m.cancel()
+
+ // Create a ping struct with our mtu size
+ p, err := NewPinger(mtuSize)
+ if err != nil {
+ return false, err
+ }
+
+ // Read the start Rx bytes
+ b, err := m.readRxBytes()
+ if err != nil {
+ return false, err
+ }
+
+ // Create a new ticker that executes our ping function every 'interval' seconds
+ // It starts immediately and stops when we reach the end
+ ticker := time.NewTicker(m.pInterval)
+ defer ticker.Stop()
+
+ // Loop until the max drop counter
+ // We begin with 1 as this is used as the sequence number for ping
+ for s := 1; s <= m.pDropped; s++ {
+ // Send a ping and return if an error occurs
+ if err := p.Send(gateway, s); err != nil {
+ return false, err
+ }
+
+ // Early alive check
+ // If not dropped, return
+ if s == m.pAlive {
+ if d, err := m.dropped(b); !d {
+ return false, err
+ }
+ }
+ // Wait for the next tick to continue
+ select {
+ case <-ticker.C:
+ continue
+ case <-ctx.Done():
+ return false, errors.New("failover was cancelled")
+ }
+ }
+
+ // Dropped check if we have not returned early
+ return m.dropped(b)
+}
+
+// Cancel cancels the dropped connection failover monitor if there is one
+func (m *DroppedConMon) Cancel() {
+ if m.cancel != nil {
+ m.cancel()
+ }
+}
diff --git a/internal/failover/ping.go b/internal/failover/ping.go
new file mode 100644
index 0000000..2ffedd4
--- /dev/null
+++ b/internal/failover/ping.go
@@ -0,0 +1,51 @@
+package failover
+
+import (
+ "fmt"
+ "net"
+ "os"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+
+ "github.com/go-errors/errors"
+)
+
+// mtuOverhead defines the total MTU overhead for an ICMP ECHO message: 20 bytes IP header + 8 bytes ICMP header
+var mtuOverhead = 28
+
+type Pinger struct {
+ listener net.PacketConn
+ buffer []byte
+}
+
+func NewPinger(size int) (*Pinger, error) {
+ l, err := icmp.ListenPacket("udp4", "0.0.0.0")
+ if err != nil {
+ return nil, errors.WrapPrefix(err, "failed creating ping", 0)
+ }
+ return &Pinger{listener: l, buffer: make([]byte, size-mtuOverhead)}, nil
+}
+
+func (p Pinger) Send(gateway string, seq int) error {
+ errorMessage := fmt.Sprintf("failed sending ping, seq %d", seq)
+ // Make a new ICMP message
+ m := icmp.Message{
+ Type: ipv4.ICMPTypeEcho, Code: 0,
+ Body: &icmp.Echo{
+ ID: os.Getpid() & 0xffff, Seq: seq,
+ Data: p.buffer,
+ },
+ }
+ // Marshal the message to bytes
+ b, err := m.Marshal(nil)
+ if err != nil {
+ return errors.WrapPrefix(err, errorMessage, 0)
+ }
+ // And send it to the gateway IP!
+ _, err = p.listener.WriteTo(b, &net.UDPAddr{IP: net.ParseIP(gateway)})
+ if err != nil {
+ return errors.WrapPrefix(err, errorMessage, 0)
+ }
+ return nil
+}
diff --git a/internal/log/log.go b/internal/log/log.go
index be4730f..64d9fb8 100644
--- a/internal/log/log.go
+++ b/internal/log/log.go
@@ -9,7 +9,6 @@ import (
"path"
"github.com/eduvpn/eduvpn-common/internal/oauth"
-
"github.com/eduvpn/eduvpn-common/internal/util"
"github.com/go-errors/errors"
)
diff --git a/internal/server/base.go b/internal/server/base.go
index 6eb305b..dd15aff 100644
--- a/internal/server/base.go
+++ b/internal/server/base.go
@@ -30,7 +30,7 @@ func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo {
for _, p := range b.Profiles.Info.ProfileList {
// Not a valid profile because it does not support openvpn
// Also the client does not support wireguard
- if !p.supportsOpenVPN() && !wireguardSupport {
+ if !p.SupportsOpenVPN() && !wireguardSupport {
continue
}
valid = append(valid, p)
diff --git a/internal/server/profile.go b/internal/server/profile.go
index 97781e4..d981421 100644
--- a/internal/server/profile.go
+++ b/internal/server/profile.go
@@ -35,10 +35,10 @@ func (profile *Profile) supportsProtocol(protocol string) bool {
return false
}
-func (profile *Profile) supportsWireguard() bool {
+func (profile *Profile) SupportsWireguard() bool {
return profile.supportsProtocol("wireguard")
}
-func (profile *Profile) supportsOpenVPN() bool {
+func (profile *Profile) SupportsOpenVPN() bool {
return profile.supportsProtocol("openvpn")
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 9354883..78f6472 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,6 +1,7 @@
package server
import (
+ "os"
"time"
"github.com/eduvpn/eduvpn-common/internal/oauth"
@@ -219,7 +220,7 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
return false, err
}
// Profile does not support OpenVPN but the client also doesn't support WireGuard
- if !p.supportsOpenVPN() && !wireguardSupport {
+ if !p.SupportsOpenVPN() && !wireguardSupport {
return false, nil
}
return true, nil
@@ -242,8 +243,9 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (string, strin
return "", "", err
}
- ovpn := p.supportsOpenVPN()
- wg := p.supportsWireguard() && wireguardSupport
+ ovpn := p.SupportsOpenVPN()
+ wg := p.SupportsWireguard() && wireguardSupport
+
// If we don't prefer TCP and this profile and client supports wireguard,
// we disable openvpn if the EDUVPN_PREFER_WG environment variable is set
// This is useful to force WireGuard if the profile supports both OpenVPN and WireGuard but the server still prefers OpenVPN