diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-12-20 15:35:44 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-12-21 18:28:00 +0100 |
| commit | 6981666c6d8f639a1ff9c09a3bc08769e19928af (patch) | |
| tree | bdb94d76a7fb6a08ef200e9bbbbd5fff1d6b134c | |
| parent | 697dfed1f9f5d2916889a81a7a64bd1158caf2d2 (diff) | |
Failover: Initial implementation
| -rw-r--r-- | client/client.go | 4 | ||||
| -rw-r--r-- | client/fsm.go | 4 | ||||
| -rw-r--r-- | client/server.go | 39 | ||||
| -rw-r--r-- | exports/exports.go | 43 | ||||
| -rw-r--r-- | go.mod | 1 | ||||
| -rw-r--r-- | go.sum | 2 | ||||
| -rw-r--r-- | internal/failover/failover.go | 21 | ||||
| -rw-r--r-- | internal/failover/monitor.go | 106 | ||||
| -rw-r--r-- | internal/failover/ping.go | 51 | ||||
| -rw-r--r-- | internal/log/log.go | 1 | ||||
| -rw-r--r-- | internal/server/base.go | 2 | ||||
| -rw-r--r-- | internal/server/profile.go | 4 | ||||
| -rw-r--r-- | internal/server/server.go | 8 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 3 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 16 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 1 |
16 files changed, 293 insertions, 13 deletions
diff --git a/client/client.go b/client/client.go index a31334a..b37c506 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,7 @@ import ( "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" + "github.com/eduvpn/eduvpn-common/internal/failover" "github.com/eduvpn/eduvpn-common/internal/fsm" "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/server" @@ -62,6 +63,9 @@ type Client struct { // Whether to enable debugging Debug bool `json:"-"` + + // The Failover monitor for the current VPN connection + Failover *failover.DroppedConMon } // Register initializes the clientwith the following parameters: diff --git a/client/fsm.go b/client/fsm.go index c156fba..76dee05 100644 --- a/client/fsm.go +++ b/client/fsm.go @@ -212,7 +212,7 @@ func (c *Client) SetSearchServer() error { return err } - // TODO(jwijenbergh): Should we handle `false` returned value here? + //TODO(jwijenbergh): Should we handle `false` returned value here? c.FSM.GoTransition(StateSearchServer) return nil } @@ -325,7 +325,7 @@ func (c *Client) SetDisconnected(cleanup bool) error { func (c *Client) goBackInternal() { err := c.GoBack() if err != nil { - // TODO(jwijenbergh): Bit suspicious - logging level INFO, yet stacktrace logged. + //TODO(jwijenbergh): Bit suspicious - logging level INFO, yet stacktrace logged. c.Logger.Infof("failed going back: %s\nstacktrace:\n%s", err.Error(), err.(*errors.Error).ErrorStack()) } } diff --git a/client/server.go b/client/server.go index 4c00986..6802a47 100644 --- a/client/server.go +++ b/client/server.go @@ -1,6 +1,7 @@ package client import ( + "github.com/eduvpn/eduvpn-common/internal/failover" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" "github.com/eduvpn/eduvpn-common/internal/util" @@ -84,8 +85,8 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, e // Save the config if err = c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. + //TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... + //TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. c.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", err.Error(), err.(*errors.Error).ErrorStack()) } @@ -584,3 +585,37 @@ func (c *Client) SetProfileID(profileID string) (err error) { b.Profiles.Current = profileID return nil } + +func (c *Client) StartFailover(gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) { + currentServer, currentServerErr := c.Servers.GetCurrentServer() + if currentServerErr != nil { + return false, currentServerErr + } + + // Check if the current profile supports OpenVPN + profile, profileErr := server.CurrentProfile(currentServer) + if profileErr != nil { + return false, profileErr + } + + if !profile.SupportsOpenVPN() { + return false, errors.New("Profile does not support OpenVPN fallback") + } + + monitor, monitorErr := failover.New(readRxBytes) + if monitorErr != nil { + return false, monitorErr + } + // Initialize the client's monitor + c.Failover = monitor + + return c.Failover.Start(gateway, wgMTU) +} + +func (c *Client) CancelFailover() error { + if c.Failover == nil { + return errors.New("No failover process") + } + c.Failover.Cancel() + return nil +} diff --git a/exports/exports.go b/exports/exports.go index 87ce331..e374661 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -4,8 +4,13 @@ package main #include <stdlib.h> #include "error.h" +typedef long long int (*ReadRxBytes)(); typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data); +static long long int get_read_rx_bytes(ReadRxBytes read) +{ + return read(); +} static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data) { return callback(name, oldstate, newstate, data); @@ -441,6 +446,44 @@ func SetSupportWireguard(name *C.char, support C.int) *C.error { return nil } +//export StartFailover +func StartFailover(name *C.char, gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.error) { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return C.int(0), getError(stateErr) + } + dropped, droppedErr := state.StartFailover(C.GoString(gateway), int(mtu), func() (int64, error) { + rxBytes := int64(C.get_read_rx_bytes(readRxBytes)) + if rxBytes == -1 { + return 0, errors.New("client gave an invalid rx bytes value") + } + return rxBytes, nil + }) + if droppedErr != nil { + return C.int(0), getError(droppedErr) + } + droppedC := C.int(0) + if dropped { + droppedC = C.int(1) + } + return droppedC, nil +} + +//export CancelFailover +func CancelFailover(name *C.char) *C.error { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return getError(stateErr) + } + cancelErr := state.CancelFailover() + if cancelErr != nil { + return getError(cancelErr) + } + return nil +} + //export FreeString func FreeString(addr *C.char) { C.free(unsafe.Pointer(addr)) @@ -10,5 +10,6 @@ require ( require ( github.com/go-errors/errors v1.4.2 golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect + golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 // indirect ) @@ -23,6 +23,7 @@ golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 h1:6mzvA99KwZxbOrxww4EvWVQUnN1+xEu9tafK5ZxkYeA= golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -46,7 +47,6 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d h1:q4JksJ2n0fmbXC0Aj0eOs6E0AcPqnKglxWXWFqGD6x0= diff --git a/internal/failover/failover.go b/internal/failover/failover.go new file mode 100644 index 0000000..f239eeb --- /dev/null +++ b/internal/failover/failover.go @@ -0,0 +1,21 @@ +package failover + +import "time" + +const ( + // Send a ping every 2 seconds to the gateway + pInterval time.Duration = 2 * time.Second + + // pAlive is how many pings we need to have sent to check if the connection is alive + pAlive int = 3 + + // pDropped is how many pings we need to have sent to check if the connection is dropped + pDropped int = 5 +) + +// New creates a failover monitor for the gateway and the rx bytes function reader +// This is a simple wrapper over `NewDroppedMonitor` to create one with the default settings +// If this function returns True, the connection is dropped. False means it has exited and we don't know for sure if it's dropped or not +func New(readRxBytes func() (int64, error)) (*DroppedConMon, error) { + return NewDroppedMonitor(pInterval, pAlive, pDropped, readRxBytes) +} diff --git a/internal/failover/monitor.go b/internal/failover/monitor.go new file mode 100644 index 0000000..d14fb9e --- /dev/null +++ b/internal/failover/monitor.go @@ -0,0 +1,106 @@ +package failover + +import ( + "context" + "time" + + "github.com/go-errors/errors" +) + +// The DroppedConMon is a connection monitor that checks for an increase in rx bytes in certain intervals +type DroppedConMon struct { + // pInterval means how the interval in which to send pings + pInterval time.Duration + // pAlive means how many pings need to be send before checking if the connection is alive + pAlive int + // pDropped means how many pings need to be send before checking if the connection is dropped + pDropped int + // The function that reads Rx bytes + // If this function returns an error, the monitor exits + readRxBytes func() (int64, error) + // The cancel context + // This is used to cancel the dropped connection monitor + cancel context.CancelFunc +} + +func NewDroppedMonitor(pingInterval time.Duration, pAlive int, pDropped int, readRxBytes func() (int64, error)) (*DroppedConMon, error) { + if pAlive >= pDropped { + return nil, errors.New("pAlive must be smaller than pDropped") + } + return &DroppedConMon{pInterval: pingInterval, pAlive: pAlive, pDropped: pDropped, readRxBytes: readRxBytes}, nil +} + +// Dropped checks whether or not the connection is 'dropped' +// In other words, it checks if rx bytes has increased +func (m *DroppedConMon) dropped(startBytes int64) (bool, error) { + b, err := m.readRxBytes() + if err != nil { + return false, err + } + return b <= startBytes, nil +} + +// Start starts ticking every ping interval and check if the connection is dropped or alive +// This does not check Rx bytes every tick, but rather when pAlive or pDropped is reached +// It returns an error if there was an invalid input or a ping was failed to be sent +func (m *DroppedConMon) Start(gateway string, mtuSize int) (bool, error) { + if mtuSize <= 0 { + return false, errors.New("invalid mtu size given") + } + + // Create a context and save the cancel function + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + defer m.cancel() + + // Create a ping struct with our mtu size + p, err := NewPinger(mtuSize) + if err != nil { + return false, err + } + + // Read the start Rx bytes + b, err := m.readRxBytes() + if err != nil { + return false, err + } + + // Create a new ticker that executes our ping function every 'interval' seconds + // It starts immediately and stops when we reach the end + ticker := time.NewTicker(m.pInterval) + defer ticker.Stop() + + // Loop until the max drop counter + // We begin with 1 as this is used as the sequence number for ping + for s := 1; s <= m.pDropped; s++ { + // Send a ping and return if an error occurs + if err := p.Send(gateway, s); err != nil { + return false, err + } + + // Early alive check + // If not dropped, return + if s == m.pAlive { + if d, err := m.dropped(b); !d { + return false, err + } + } + // Wait for the next tick to continue + select { + case <-ticker.C: + continue + case <-ctx.Done(): + return false, errors.New("failover was cancelled") + } + } + + // Dropped check if we have not returned early + return m.dropped(b) +} + +// Cancel cancels the dropped connection failover monitor if there is one +func (m *DroppedConMon) Cancel() { + if m.cancel != nil { + m.cancel() + } +} diff --git a/internal/failover/ping.go b/internal/failover/ping.go new file mode 100644 index 0000000..2ffedd4 --- /dev/null +++ b/internal/failover/ping.go @@ -0,0 +1,51 @@ +package failover + +import ( + "fmt" + "net" + "os" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + + "github.com/go-errors/errors" +) + +// mtuOverhead defines the total MTU overhead for an ICMP ECHO message: 20 bytes IP header + 8 bytes ICMP header +var mtuOverhead = 28 + +type Pinger struct { + listener net.PacketConn + buffer []byte +} + +func NewPinger(size int) (*Pinger, error) { + l, err := icmp.ListenPacket("udp4", "0.0.0.0") + if err != nil { + return nil, errors.WrapPrefix(err, "failed creating ping", 0) + } + return &Pinger{listener: l, buffer: make([]byte, size-mtuOverhead)}, nil +} + +func (p Pinger) Send(gateway string, seq int) error { + errorMessage := fmt.Sprintf("failed sending ping, seq %d", seq) + // Make a new ICMP message + m := icmp.Message{ + Type: ipv4.ICMPTypeEcho, Code: 0, + Body: &icmp.Echo{ + ID: os.Getpid() & 0xffff, Seq: seq, + Data: p.buffer, + }, + } + // Marshal the message to bytes + b, err := m.Marshal(nil) + if err != nil { + return errors.WrapPrefix(err, errorMessage, 0) + } + // And send it to the gateway IP! + _, err = p.listener.WriteTo(b, &net.UDPAddr{IP: net.ParseIP(gateway)}) + if err != nil { + return errors.WrapPrefix(err, errorMessage, 0) + } + return nil +} diff --git a/internal/log/log.go b/internal/log/log.go index be4730f..64d9fb8 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -9,7 +9,6 @@ import ( "path" "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/util" "github.com/go-errors/errors" ) diff --git a/internal/server/base.go b/internal/server/base.go index 6eb305b..dd15aff 100644 --- a/internal/server/base.go +++ b/internal/server/base.go @@ -30,7 +30,7 @@ func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo { for _, p := range b.Profiles.Info.ProfileList { // Not a valid profile because it does not support openvpn // Also the client does not support wireguard - if !p.supportsOpenVPN() && !wireguardSupport { + if !p.SupportsOpenVPN() && !wireguardSupport { continue } valid = append(valid, p) diff --git a/internal/server/profile.go b/internal/server/profile.go index 97781e4..d981421 100644 --- a/internal/server/profile.go +++ b/internal/server/profile.go @@ -35,10 +35,10 @@ func (profile *Profile) supportsProtocol(protocol string) bool { return false } -func (profile *Profile) supportsWireguard() bool { +func (profile *Profile) SupportsWireguard() bool { return profile.supportsProtocol("wireguard") } -func (profile *Profile) supportsOpenVPN() bool { +func (profile *Profile) SupportsOpenVPN() bool { return profile.supportsProtocol("openvpn") } diff --git a/internal/server/server.go b/internal/server/server.go index 9354883..78f6472 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,6 +1,7 @@ package server import ( + "os" "time" "github.com/eduvpn/eduvpn-common/internal/oauth" @@ -219,7 +220,7 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { return false, err } // Profile does not support OpenVPN but the client also doesn't support WireGuard - if !p.supportsOpenVPN() && !wireguardSupport { + if !p.SupportsOpenVPN() && !wireguardSupport { return false, nil } return true, nil @@ -242,8 +243,9 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (string, strin return "", "", err } - ovpn := p.supportsOpenVPN() - wg := p.supportsWireguard() && wireguardSupport + ovpn := p.SupportsOpenVPN() + wg := p.SupportsWireguard() && wireguardSupport + // If we don't prefer TCP and this profile and client supports wireguard, // we disable openvpn if the EDUVPN_PREFER_WG environment variable is set // This is useful to force WireGuard if the profile supports both OpenVPN and WireGuard but the server still prefers OpenVPN diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index 3de3de5..1090619 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -7,6 +7,7 @@ from eduvpn_common import __version__ from eduvpn_common.types import ( ConfigError, DataError, + ReadRxBytes, VPNStateChange, ) @@ -151,3 +152,5 @@ def initialize_functions(lib: CDLL) -> None: c_int, ], c_void_p lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int + lib.StartFailover.argtypes, lib.StartFailover.restype = [c_char_p, c_char_p, c_int, ReadRxBytes], DataError + lib.CancelFailover.argtypes, lib.CancelFailover.restype = [c_char_p], c_void_p diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 20d646f..3cb45e1 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -7,7 +7,7 @@ from eduvpn_common.event import EventHandler from eduvpn_common.loader import initialize_functions, load_lib from eduvpn_common.server import Profiles, Server, get_transition_server, get_servers from eduvpn_common.state import State, StateType -from eduvpn_common.types import VPNStateChange, decode_res, encode_args, get_data_error +from eduvpn_common.types import ReadRxBytes, VPNStateChange, decode_res, encode_args, get_data_error, get_bool class EduVPN(object): @@ -502,6 +502,20 @@ class EduVPN(object): return servers + def start_failover(self, gateway: str, wg_mtu: int, readrxbytes: ReadRxBytes) -> bool: + dropped, dropped_err = self.go_function( + self.lib.StartFailover, gateway, wg_mtu, readrxbytes, + decode_func=lambda lib, x: get_data_error(lib, x, get_bool), + ) + if dropped_err: + raise dropped_err + return dropped + + def cancel_failover(self): + cancel_err = self.go_function(self.lib.CancelFailover) + if cancel_err: + raise cancel_err + eduvpn_objects: Dict[str, EduVPN] = {} diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index 07a02d3..7e3ce9a 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -166,6 +166,7 @@ class ConfigError(Structure): # The type for a Go state change callback VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p) +ReadRxBytes = CFUNCTYPE(c_ulonglong) def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: |
