From b928e327e0653bd94f26ac1d78179958e28f0265 Mon Sep 17 00:00:00 2001
From: Steven Weathers <steven@weathers.me>
Date: Sun, 22 Sep 2024 21:27:50 -0400
Subject: [PATCH] Add unit tests for user subscribed middleware functions

---
 internal/http/middleware_test.go              | 257 ++++++++++++++++++
 internal/http/types.go                        |  13 +-
 internal/oauth/oauth.go                       |   2 +-
 internal/oauth/type.go                        |   6 +-
 internal/webhook/subscription/subscription.go |  10 +-
 thunderdome/subscription.go                   |  12 -
 6 files changed, 283 insertions(+), 17 deletions(-)

diff --git a/internal/http/middleware_test.go b/internal/http/middleware_test.go
index 586ebc74..bf2df9df 100644
--- a/internal/http/middleware_test.go
+++ b/internal/http/middleware_test.go
@@ -1405,3 +1405,260 @@ func TestOrgUserOnly(t *testing.T) {
 		})
 	}
 }
+
+func TestSubscribedUserOnly(t *testing.T) {
+	tests := []struct {
+		name                 string
+		userID               string
+		userType             string
+		subscriptionsEnabled bool
+		expectedStatus       int
+		mockSetup            func(mockSubDataSvc *MockSubscriptionDataService)
+	}{
+		{
+			name:                 "Subscriptions Disabled",
+			userID:               "123e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			subscriptionsEnabled: false,
+			expectedStatus:       http.StatusOK,
+			mockSetup:            func(mockSubDataSvc *MockSubscriptionDataService) {},
+		},
+		{
+			name:                 "Admin User",
+			userID:               "223e4567-e89b-12d3-a456-426614174000",
+			userType:             thunderdome.AdminUserType,
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusOK,
+			mockSetup:            func(mockSubDataSvc *MockSubscriptionDataService) {},
+		},
+		{
+			name:                 "Subscribed Regular User",
+			userID:               "323e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusOK,
+			mockSetup: func(mockSubDataSvc *MockSubscriptionDataService) {
+				mockSubDataSvc.On("CheckActiveSubscriber", mock.Anything, "323e4567-e89b-12d3-a456-426614174000").Return(nil)
+			},
+		},
+		{
+			name:                 "Unsubscribed Regular User",
+			userID:               "423e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusForbidden,
+			mockSetup: func(mockSubDataSvc *MockSubscriptionDataService) {
+				mockSubDataSvc.On("CheckActiveSubscriber", mock.Anything, "423e4567-e89b-12d3-a456-426614174000").Return(errors.New("not subscribed"))
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Mock SubscriptionDataService
+			mockSubDataSvc := new(MockSubscriptionDataService)
+
+			// Create a new service with the mock and config
+			s := &Service{
+				SubscriptionDataSvc: mockSubDataSvc,
+				Config: &Config{
+					SubscriptionsEnabled: tt.subscriptionsEnabled,
+				},
+			}
+
+			// Define a dummy handler for testing
+			dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusOK)
+			})
+
+			// Setup mock expectations
+			tt.mockSetup(mockSubDataSvc)
+
+			// Create a new request
+			req, err := http.NewRequest("GET", "/test", nil)
+			assert.NoError(t, err)
+
+			// Create a new response recorder
+			rr := httptest.NewRecorder()
+
+			// Set up the context with user information
+			ctx := context.WithValue(req.Context(), contextKeyUserID, tt.userID)
+			ctx = context.WithValue(ctx, contextKeyUserType, tt.userType)
+			req = req.WithContext(ctx)
+
+			// Call the middleware
+			handler := s.subscribedUserOnly(dummyHandler)
+			handler.ServeHTTP(rr, req)
+
+			// Check the status code
+			assert.Equal(t, tt.expectedStatus, rr.Code)
+
+			// Clear mock expectations for the next test
+			mockSubDataSvc.AssertExpectations(t)
+		})
+	}
+}
+
+func TestSubscribedEntityUserOnly(t *testing.T) {
+	tests := []struct {
+		name                 string
+		userID               string
+		userType             string
+		entityUserID         string
+		subscriptionsEnabled bool
+		expectedStatus       int
+		mockSetup            func(mockSubDataSvc *MockSubscriptionDataService)
+	}{
+		{
+			name:                 "Admin User",
+			userID:               "123e4567-e89b-12d3-a456-426614174000",
+			userType:             thunderdome.AdminUserType,
+			entityUserID:         "223e4567-e89b-12d3-a456-426614174000",
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusOK,
+			mockSetup:            func(mockSubDataSvc *MockSubscriptionDataService) {},
+		},
+		{
+			name:                 "Matching User ID",
+			userID:               "323e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			entityUserID:         "323e4567-e89b-12d3-a456-426614174000",
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusOK,
+			mockSetup: func(mockSubDataSvc *MockSubscriptionDataService) {
+				mockSubDataSvc.On("CheckActiveSubscriber", mock.Anything, "323e4567-e89b-12d3-a456-426614174000").Return(nil)
+			},
+		},
+		{
+			name:                 "Non-Matching User ID",
+			userID:               "423e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			entityUserID:         "523e4567-e89b-12d3-a456-426614174000",
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusForbidden,
+			mockSetup:            func(mockSubDataSvc *MockSubscriptionDataService) {},
+		},
+		{
+			name:                 "Subscriptions Disabled",
+			userID:               "623e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			entityUserID:         "623e4567-e89b-12d3-a456-426614174000",
+			subscriptionsEnabled: false,
+			expectedStatus:       http.StatusOK,
+			mockSetup:            func(mockSubDataSvc *MockSubscriptionDataService) {},
+		},
+		{
+			name:                 "Unsubscribed User",
+			userID:               "723e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			entityUserID:         "723e4567-e89b-12d3-a456-426614174000",
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusForbidden,
+			mockSetup: func(mockSubDataSvc *MockSubscriptionDataService) {
+				mockSubDataSvc.On("CheckActiveSubscriber", mock.Anything, "723e4567-e89b-12d3-a456-426614174000").Return(errors.New("not subscribed"))
+			},
+		},
+		{
+			name:                 "Invalid Entity User ID",
+			userID:               "823e4567-e89b-12d3-a456-426614174000",
+			userType:             "REGULAR",
+			entityUserID:         "invalid-user-id",
+			subscriptionsEnabled: true,
+			expectedStatus:       http.StatusBadRequest,
+			mockSetup:            func(mockSubDataSvc *MockSubscriptionDataService) {},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Mock SubscriptionDataService
+			mockSubDataSvc := new(MockSubscriptionDataService)
+
+			// Create a new service with the mock and config
+			s := &Service{
+				SubscriptionDataSvc: mockSubDataSvc,
+				Config: &Config{
+					SubscriptionsEnabled: tt.subscriptionsEnabled,
+				},
+			}
+
+			// Define a dummy handler for testing
+			dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusOK)
+			})
+
+			// Setup mock expectations
+			tt.mockSetup(mockSubDataSvc)
+
+			// Create a new request
+			req, err := http.NewRequest("GET", "/users/"+tt.entityUserID, nil)
+			assert.NoError(t, err)
+
+			// Create a new response recorder
+			rr := httptest.NewRecorder()
+
+			// Set up the context with user information
+			ctx := context.WithValue(req.Context(), contextKeyUserID, tt.userID)
+			ctx = context.WithValue(ctx, contextKeyUserType, tt.userType)
+			req = req.WithContext(ctx)
+
+			// Set up router with vars
+			router := mux.NewRouter()
+			router.HandleFunc("/users/{userId}", s.subscribedEntityUserOnly(dummyHandler))
+
+			// Serve the request
+			router.ServeHTTP(rr, req)
+
+			// Check the status code
+			assert.Equal(t, tt.expectedStatus, rr.Code)
+
+			// Clear mock expectations for the next test
+			mockSubDataSvc.AssertExpectations(t)
+		})
+	}
+}
+
+// MockSubscriptionDataService is a mock of SubscriptionDataService
+type MockSubscriptionDataService struct {
+	mock.Mock
+}
+
+func (m *MockSubscriptionDataService) GetSubscriptionByID(ctx context.Context, id string) (thunderdome.Subscription, error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (m *MockSubscriptionDataService) GetSubscriptionBySubscriptionID(ctx context.Context, subscriptionId string) (thunderdome.Subscription, error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (m *MockSubscriptionDataService) GetActiveSubscriptionsByUserID(ctx context.Context, userId string) ([]thunderdome.Subscription, error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (m *MockSubscriptionDataService) CreateSubscription(ctx context.Context, subscription thunderdome.Subscription) (thunderdome.Subscription, error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (m *MockSubscriptionDataService) UpdateSubscription(ctx context.Context, id string, sub thunderdome.Subscription) (thunderdome.Subscription, error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (m *MockSubscriptionDataService) GetSubscriptions(ctx context.Context, Limit int, Offset int) ([]thunderdome.Subscription, int, error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (m *MockSubscriptionDataService) DeleteSubscription(ctx context.Context, id string) error {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (m *MockSubscriptionDataService) CheckActiveSubscriber(ctx context.Context, userID string) error {
+	args := m.Called(ctx, userID)
+	return args.Error(0)
+}
diff --git a/internal/http/types.go b/internal/http/types.go
index 942252cb..9282ff49 100644
--- a/internal/http/types.go
+++ b/internal/http/types.go
@@ -121,7 +121,7 @@ type Service struct {
 	OrganizationDataSvc  OrganizationDataSvc
 	AdminDataSvc         AdminDataSvc
 	JiraDataSvc          JiraDataSvc
-	SubscriptionDataSvc  thunderdome.SubscriptionDataSvc
+	SubscriptionDataSvc  SubscriptionDataSvc
 	RetroTemplateDataSvc thunderdome.RetroTemplateDataSvc
 	SubscriptionSvc      *subscription.Service
 }
@@ -291,3 +291,14 @@ type TeamDataSvc interface {
 	GetTeamMetrics(ctx context.Context, teamID string) (*thunderdome.TeamMetrics, error)
 	TeamUserRoles(ctx context.Context, UserID string, TeamID string) (*thunderdome.UserTeamRoleInfo, error)
 }
+
+type SubscriptionDataSvc interface {
+	CheckActiveSubscriber(ctx context.Context, userId string) error
+	GetSubscriptionByID(ctx context.Context, id string) (thunderdome.Subscription, error)
+	GetSubscriptionBySubscriptionID(ctx context.Context, subscriptionId string) (thunderdome.Subscription, error)
+	GetActiveSubscriptionsByUserID(ctx context.Context, userId string) ([]thunderdome.Subscription, error)
+	CreateSubscription(ctx context.Context, subscription thunderdome.Subscription) (thunderdome.Subscription, error)
+	UpdateSubscription(ctx context.Context, id string, sub thunderdome.Subscription) (thunderdome.Subscription, error)
+	GetSubscriptions(ctx context.Context, Limit int, Offset int) ([]thunderdome.Subscription, int, error)
+	DeleteSubscription(ctx context.Context, id string) error
+}
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index ab5f20f0..c7961476 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -19,7 +19,7 @@ func New(
 	cookie CookieManager,
 	logger *otelzap.Logger,
 	authDataSvc AuthDataSvc,
-	subscriptionDataSvc thunderdome.SubscriptionDataSvc,
+	subscriptionDataSvc SubscriptionDataSvc,
 	ctx context.Context,
 ) (*Service, error) {
 	s := Service{
diff --git a/internal/oauth/type.go b/internal/oauth/type.go
index c9bfb5d4..8373ccb7 100644
--- a/internal/oauth/type.go
+++ b/internal/oauth/type.go
@@ -37,6 +37,10 @@ type AuthDataSvc interface {
 	OauthAuthUser(ctx context.Context, provider string, sub string, email string, emailVerified bool, name string, pictureUrl string) (*thunderdome.User, string, error)
 }
 
+type SubscriptionDataSvc interface {
+	CheckActiveSubscriber(ctx context.Context, userId string) error
+}
+
 type Service struct {
 	config              Config
 	cookie              CookieManager
@@ -44,5 +48,5 @@ type Service struct {
 	logger              *otelzap.Logger
 	verifier            *oidc.IDTokenVerifier
 	authDataSvc         AuthDataSvc
-	subscriptionDataSvc thunderdome.SubscriptionDataSvc
+	subscriptionDataSvc SubscriptionDataSvc
 }
diff --git a/internal/webhook/subscription/subscription.go b/internal/webhook/subscription/subscription.go
index e37f39d5..7d1631f5 100644
--- a/internal/webhook/subscription/subscription.go
+++ b/internal/webhook/subscription/subscription.go
@@ -26,10 +26,16 @@ type Config struct {
 	WebhookSecret string
 }
 
+type DataSvc interface {
+	GetSubscriptionBySubscriptionID(ctx context.Context, subscriptionId string) (thunderdome.Subscription, error)
+	CreateSubscription(ctx context.Context, subscription thunderdome.Subscription) (thunderdome.Subscription, error)
+	UpdateSubscription(ctx context.Context, id string, sub thunderdome.Subscription) (thunderdome.Subscription, error)
+}
+
 type Service struct {
 	config      Config
 	logger      *otelzap.Logger
-	dataSvc     thunderdome.SubscriptionDataSvc
+	dataSvc     DataSvc
 	emailSvc    thunderdome.EmailService
 	userDataSvc thunderdome.UserDataSvc
 }
@@ -37,7 +43,7 @@ type Service struct {
 func New(
 	config Config,
 	logger *otelzap.Logger,
-	dataSvc thunderdome.SubscriptionDataSvc,
+	dataSvc DataSvc,
 	emailSvc thunderdome.EmailService,
 	userDataSvc thunderdome.UserDataSvc,
 ) *Service {
diff --git a/thunderdome/subscription.go b/thunderdome/subscription.go
index b217c7ec..ef780b9e 100644
--- a/thunderdome/subscription.go
+++ b/thunderdome/subscription.go
@@ -1,7 +1,6 @@
 package thunderdome
 
 import (
-	"context"
 	"time"
 )
 
@@ -19,14 +18,3 @@ type Subscription struct {
 	UpdatedDate    time.Time `json:"updated_date"`
 	User           User      `json:"user"`
 }
-
-type SubscriptionDataSvc interface {
-	CheckActiveSubscriber(ctx context.Context, userId string) error
-	GetSubscriptionByID(ctx context.Context, id string) (Subscription, error)
-	GetSubscriptionBySubscriptionID(ctx context.Context, subscriptionId string) (Subscription, error)
-	GetActiveSubscriptionsByUserID(ctx context.Context, userId string) ([]Subscription, error)
-	CreateSubscription(ctx context.Context, subscription Subscription) (Subscription, error)
-	UpdateSubscription(ctx context.Context, id string, sub Subscription) (Subscription, error)
-	GetSubscriptions(ctx context.Context, Limit int, Offset int) ([]Subscription, int, error)
-	DeleteSubscription(ctx context.Context, id string) error
-}