diff --git a/accounts/worker.go b/accounts/worker.go index 40c018f..3079d8c 100644 --- a/accounts/worker.go +++ b/accounts/worker.go @@ -1,24 +1,21 @@ package accounts import ( - "fmt" - "io" "net" "net/http" "net/url" - "os" - "strconv" "time" + "code.cloudfoundry.org/garden/client/connection" "code.cloudfoundry.org/lager" - v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/cli-runtime/pkg/genericclioptions" "k8s.io/client-go/rest" - "k8s.io/client-go/transport/spdy" - - "code.cloudfoundry.org/garden/client/connection" "k8s.io/client-go/tools/portforward" + "k8s.io/client-go/transport/spdy" + "k8s.io/kubectl/pkg/scheme" ) type GardenWorker struct { @@ -67,13 +64,12 @@ func (kgd *K8sGardenDialer) Dial() (net.Conn, error) { if err != nil { return nil, err } - dialer := spdy.NewDialer( + streamConn, _, err := spdy.NewDialer( upgrader, &http.Client{Transport: transport}, "POST", url, - ) - streamConn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name) + ).Dial(portforward.PortForwardProtocolV1Name) // TODO why should this error? Test if err != nil { return nil, err @@ -81,18 +77,15 @@ func (kgd *K8sGardenDialer) Dial() (net.Conn, error) { headers := http.Header{} headers.Set(v1.StreamType, v1.StreamTypeData) headers.Set(v1.PortHeader, "7777") - - // TODO do we need this: - headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(0)) - + headers.Set(v1.PortForwardRequestIDHeader, "0") stream, err := streamConn.CreateStream(headers) - headers.Set(v1.StreamType, v1.StreamTypeError) - errorStream, err := streamConn.CreateStream(headers) // TODO why should this error? Test if err != nil { return nil, err } - go io.Copy(errorStream, os.Stdout) + + headers.Set(v1.StreamType, v1.StreamTypeError) + streamConn.CreateStream(headers) return &StreamConn{streamConn, stream}, nil } @@ -105,10 +98,26 @@ type K8sConnection interface { type systemK8sConnection struct { restConfig *rest.Config + namespace string + podName string } -func NewK8sConnection(restConfig *rest.Config) K8sConnection { - return &systemK8sConnection{restConfig} +func NewK8sConnection(namespace, podName string) (K8sConnection, error) { + restConfig, err := genericclioptions. + NewConfigFlags(true). + WithDeprecatedPasswordFlag(). + ToRESTConfig() + if err != nil { + return nil, err + } + restConfig.GroupVersion = &schema.GroupVersion{Group: "", Version: "v1"} + restConfig.NegotiatedSerializer = scheme.Codecs.WithoutConversion() + restConfig.APIPath = "/api" + return &systemK8sConnection{ + restConfig: restConfig, + namespace: namespace, + podName: podName, + }, nil } func (kc *systemK8sConnection) RESTConfig() *rest.Config { @@ -116,8 +125,6 @@ func (kc *systemK8sConnection) RESTConfig() *rest.Config { } func (kc *systemK8sConnection) URL() (*url.URL, error) { - namespace := "ci" - podName := "ci-worker-0" restClient, err := rest.RESTClientFor(kc.restConfig) if err != nil { return nil, err @@ -125,8 +132,8 @@ func (kc *systemK8sConnection) URL() (*url.URL, error) { return restClient. Post(). Resource("pods"). - Namespace(namespace). - Name(podName). + Namespace(kc.namespace). + Name(kc.podName). SubResource("portforward"). URL(), nil } @@ -148,41 +155,33 @@ func (sa *StreamAddr) String() string { } func (sc *StreamConn) Write(p []byte) (n int, err error) { - fmt.Println("Write", string(p)) return sc.stream.Write(p) } func (sc *StreamConn) Read(p []byte) (n int, err error) { - fmt.Println("Read", string(p)) return sc.stream.Read(p) } func (sc *StreamConn) Close() error { - fmt.Println("Close") return sc.conn.Close() } func (sc *StreamConn) LocalAddr() net.Addr { - fmt.Println("LocalAddr") return &StreamAddr{} } func (sc *StreamConn) RemoteAddr() net.Addr { - fmt.Println("RemoteAddr") return &StreamAddr{} } func (sc *StreamConn) SetDeadline(t time.Time) error { - fmt.Println("SetDeadline", t) return nil } func (sc *StreamConn) SetReadDeadline(t time.Time) error { - fmt.Println("SetReadDeadline", t) return nil } func (sc *StreamConn) SetWriteDeadline(t time.Time) error { - fmt.Println("SetReadDeadline", t) return nil } diff --git a/accounts/worker_test.go b/accounts/worker_test.go index 18c6bd1..022e1b0 100644 --- a/accounts/worker_test.go +++ b/accounts/worker_test.go @@ -18,16 +18,10 @@ import ( . "github.com/onsi/gomega" "github.com/onsi/gomega/gbytes" "github.com/onsi/gomega/gstruct" - - // corev1 "k8s.io/api/core/v1" - // metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2" - "k8s.io/client-go/tools/remotecommand" - - "k8s.io/kubernetes/pkg/kubelet/server/streaming" - + runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2" cmdtesting "k8s.io/kubectl/pkg/cmd/testing" + "k8s.io/kubernetes/pkg/kubelet/server/streaming" ) var _ = Describe("Worker", func() { @@ -86,15 +80,14 @@ var _ = Describe("Worker", func() { io.Copy(buf, conn) return nil } - resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{ PodSandboxId: "foo", Port: []int32{7777}, }) Expect(err).NotTo(HaveOccurred()) - k8sConn := new(accountsfakes.FakeK8sConnection) testURL, err := url.Parse(resp.Url) Expect(err).NotTo(HaveOccurred()) + k8sConn := new(accountsfakes.FakeK8sConnection) k8sConn.URLReturns(testURL, nil) k8sConn.RESTConfigReturns(cmdtesting.DefaultClientConfig()) dialer := accounts.K8sGardenDialer{ @@ -102,14 +95,17 @@ var _ = Describe("Worker", func() { } conn, err := dialer.Dial() + Expect(err).NotTo(HaveOccurred()) conn.Write([]byte("hello world")) conn.Close() - Expect(buf).To(gbytes.Say("hello world")) + Eventually(buf).Should(gbytes.Say("hello world")) }) }) }) +// TODO we can probably use a counterfeiter fake for this + type fakeRuntime struct { execFunc func(string, []string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error attachFunc func(string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error @@ -137,6 +133,11 @@ type testStreamingServer struct { func newTestStreamingServer(streamIdleTimeout time.Duration) (s *testStreamingServer, err error) { s = &testStreamingServer{} s.testHTTPServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TODO we can probably make this smart enough to take a + // request to /api/v1/namespaces/ns/pods/pod/portforward + // and actually do a GetPortForward on the underlying + // StreamingServer, so that logic doesn't need to live in + // the body of the test. s.ServeHTTP(w, r) })) defer func() { diff --git a/main.go b/main.go index 7695598..b5c48b9 100644 --- a/main.go +++ b/main.go @@ -5,9 +5,7 @@ import ( "os" "strings" - "k8s.io/cli-runtime/pkg/genericclioptions" _ "k8s.io/client-go/plugin/pkg/client/auth" - cmdutil "k8s.io/kubectl/pkg/cmd/util" "github.com/concourse/concourse/fly/ui" "github.com/concourse/ctop/accounts" @@ -16,26 +14,37 @@ import ( flags "github.com/jessevdk/go-flags" ) +type Command struct { + Postgres flag.PostgresConfig `group:"PostgreSQL Configuration" namespace:"postgres"` + K8sNamespace string `long:"k8s-namespace"` + K8sPod string `long:"k8s-pod"` +} + func main() { - postgresConfig := flag.PostgresConfig{} - parser := flags.NewParser(&postgresConfig, flags.HelpFlag|flags.PassDoubleDash) + cmd := Command{} + parser := flags.NewParser(&cmd, flags.HelpFlag|flags.PassDoubleDash) + parser.NamespaceDelimiter = "-" _, err := parser.Parse() if err != nil { panic(err) } - // worker := accounts.NewLANWorker() - kubeConfigFlags := genericclioptions.NewConfigFlags(true).WithDeprecatedPasswordFlag() - f := cmdutil.NewFactory(kubeConfigFlags) - restConfig, err := f.ToRESTConfig() - if err != nil { - panic(err) + var dialer accounts.GardenDialer + if cmd.K8sNamespace != "" && cmd.K8sPod != "" { + k8sConn, err := accounts.NewK8sConnection( + cmd.K8sNamespace, + cmd.K8sPod, + ) + if err != nil { + panic(err) + } + dialer = &accounts.K8sGardenDialer{Conn: k8sConn} + } else { + dialer = &accounts.LANGardenDialer{} } worker := &accounts.GardenWorker{ - Dialer: &accounts.K8sGardenDialer{ - Conn: accounts.NewK8sConnection(restConfig), - }, + Dialer: dialer, } - accountant := accounts.NewDBAccountant(postgresConfig) + accountant := accounts.NewDBAccountant(cmd.Postgres) samples, err := accounts.Account(worker, accountant) if err != nil { panic(err)