diff options
| -rw-r--r-- | internal/discovery/discovery_test.go | 4 | ||||
| -rw-r--r-- | internal/test/server.go | 45 |
2 files changed, 45 insertions, 4 deletions
diff --git a/internal/discovery/discovery_test.go b/internal/discovery/discovery_test.go index 317aa50..495efd2 100644 --- a/internal/discovery/discovery_test.go +++ b/internal/discovery/discovery_test.go @@ -15,7 +15,7 @@ import ( // It setups up a file server using the 'test_files' directory func TestServers(t *testing.T) { handler := http.FileServer(http.Dir("test_files")) - s := test.NewServer(handler) + s := test.NewServer(handler, nil) DiscoURL = s.URL c, err := s.Client() if err != nil { @@ -57,7 +57,7 @@ func TestServers(t *testing.T) { // It setups up a file server using the 'test_files' directory func TestOrganizations(t *testing.T) { handler := http.FileServer(http.Dir("test_files")) - s := test.NewServer(handler) + s := test.NewServer(handler, nil) DiscoURL = s.URL c, err := s.Client() if err != nil { diff --git a/internal/test/server.go b/internal/test/server.go index 6c1b418..ee00656 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -4,6 +4,7 @@ package test import ( "crypto/tls" "crypto/x509" + "net" "net/http" "net/http/httptest" @@ -16,11 +17,51 @@ type Server struct { } // NewServer creates a new test server -func NewServer(handler http.Handler) *Server { - s := httptest.NewTLSServer(handler) +func NewServer(handler http.Handler, listener net.Listener) *Server { + if listener == nil { + s := httptest.NewTLSServer(handler) + return &Server{s} + } + + s := httptest.NewUnstartedServer(handler) + s.Listener.Close() + s.Listener = listener + s.StartTLS() return &Server{s} } +type HandlerPath struct { + Method string + Path string + Response string + ResponseHandler func(http.ResponseWriter, *http.Request) + ResponseCode int +} + +func (hp *HandlerPath) HandlerFunc() func(http.ResponseWriter, *http.Request) { + if hp.ResponseHandler != nil { + return hp.ResponseHandler + } + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != hp.Method { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + w.WriteHeader(hp.ResponseCode) + w.Write([]byte(hp.Response)) + } +} + +// NewServerWithHandles creates a new test servers with path and responses +func NewServerWithHandles(hps []HandlerPath, listener net.Listener) *Server { + mux := http.NewServeMux() + for _, hp := range hps { + hp := hp + mux.HandleFunc(hp.Path, hp.HandlerFunc()) + } + return NewServer(mux, listener) +} + // Client returns a test client that trusts the HTTPS certificates func (srv *Server) Client() (*httpw.Client, error) { // Get the certs from the test server |
