diff --git a/AmplifyPlugins/Core/AWSPluginsCore/Model/Decorator/FilterDecorator.swift b/AmplifyPlugins/Core/AWSPluginsCore/Model/Decorator/FilterDecorator.swift index bd78aad37f..ad7cc770c9 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/Model/Decorator/FilterDecorator.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/Model/Decorator/FilterDecorator.swift @@ -37,6 +37,9 @@ public struct FilterDecorator: ModelBasedGraphQLDocumentDecorator { } else if case .query = document.operationType { inputs["filter"] = GraphQLDocumentInput(type: "Model\(modelName)FilterInput", value: .object(filter)) + } else if case .subscription = document.operationType { + inputs["filter"] = GraphQLDocumentInput(type: "ModelSubscription\(modelName)FilterInput", + value: .object(filter)) } return document.copy(inputs: inputs) diff --git a/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+AnyModelWithSync.swift b/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+AnyModelWithSync.swift index 44af846765..05b0a9dd7e 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+AnyModelWithSync.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/Model/GraphQLRequest/GraphQLRequest+AnyModelWithSync.swift @@ -37,10 +37,12 @@ protocol ModelSyncGraphQLRequestFactory { authType: AWSAuthorizationType?) -> GraphQLRequest static func subscription(to modelSchema: ModelSchema, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, authType: AWSAuthorizationType?) -> GraphQLRequest static func subscription(to modelSchema: ModelSchema, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, claims: IdentityClaimsDictionary, authType: AWSAuthorizationType?) -> GraphQLRequest @@ -94,16 +96,18 @@ extension GraphQLRequest: ModelSyncGraphQLRequestFactory { } public static func subscription(to modelType: Model.Type, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, authType: AWSAuthorizationType? = nil) -> GraphQLRequest { - subscription(to: modelType.schema, subscriptionType: subscriptionType, authType: authType) + subscription(to: modelType.schema, where: predicate, subscriptionType: subscriptionType, authType: authType) } public static func subscription(to modelType: Model.Type, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, claims: IdentityClaimsDictionary, authType: AWSAuthorizationType? = nil) -> GraphQLRequest { - subscription(to: modelType.schema, subscriptionType: subscriptionType, claims: claims, authType: authType) + subscription(to: modelType.schema, where: predicate, subscriptionType: subscriptionType, claims: claims, authType: authType) } public static func syncQuery(modelType: Model.Type, @@ -169,12 +173,18 @@ extension GraphQLRequest: ModelSyncGraphQLRequestFactory { } public static func subscription(to modelSchema: ModelSchema, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, authType: AWSAuthorizationType? = nil) -> GraphQLRequest { var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: modelSchema, operationType: .subscription, primaryKeysOnly: true) + + if let predicate = optimizePredicate(predicate) { + documentBuilder.add(decorator: FilterDecorator(filter: predicate.graphQLFilter(for: modelSchema))) + } + documentBuilder.add(decorator: DirectiveNameDecorator(type: subscriptionType)) documentBuilder.add(decorator: ConflictResolutionDecorator(graphQLType: .subscription, primaryKeysOnly: true)) documentBuilder.add(decorator: AuthRuleDecorator(.subscription(subscriptionType, nil), authType: authType)) @@ -190,6 +200,7 @@ extension GraphQLRequest: ModelSyncGraphQLRequestFactory { } public static func subscription(to modelSchema: ModelSchema, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, claims: IdentityClaimsDictionary, authType: AWSAuthorizationType? = nil) -> GraphQLRequest { @@ -197,6 +208,11 @@ extension GraphQLRequest: ModelSyncGraphQLRequestFactory { var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: modelSchema, operationType: .subscription, primaryKeysOnly: true) + + if let predicate = optimizePredicate(predicate) { + documentBuilder.add(decorator: FilterDecorator(filter: predicate.graphQLFilter(for: modelSchema))) + } + documentBuilder.add(decorator: DirectiveNameDecorator(type: subscriptionType)) documentBuilder.add(decorator: ConflictResolutionDecorator(graphQLType: .subscription, primaryKeysOnly: true)) documentBuilder.add(decorator: AuthRuleDecorator(.subscription(subscriptionType, claims), authType: authType)) diff --git a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift index d5dae69b37..0066b7ab3a 100644 --- a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift +++ b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift @@ -78,7 +78,8 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { self.onCreateValueListener = onCreateValueListener self.onCreateOperation = RetryableGraphQLSubscriptionOperation( requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor( - for: modelSchema, + for: modelSchema, + where: modelPredicate, subscriptionType: .onCreate, api: api, auth: auth, @@ -100,6 +101,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { self.onUpdateOperation = RetryableGraphQLSubscriptionOperation( requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor( for: modelSchema, + where: modelPredicate, subscriptionType: .onUpdate, api: api, auth: auth, @@ -120,7 +122,8 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { self.onDeleteValueListener = onDeleteValueListener self.onDeleteOperation = RetryableGraphQLSubscriptionOperation( requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor( - for: modelSchema, + for: modelSchema, + where: modelPredicate, subscriptionType: .onDelete, api: api, auth: auth, @@ -195,6 +198,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { } static func makeAPIRequest(for modelSchema: ModelSchema, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, api: APICategoryGraphQLBehaviorExtended, auth: AuthCategoryBehavior?, @@ -205,7 +209,8 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { let _ = auth, let tokenString = try? await awsAuthService.getUserPoolAccessToken(), case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) { - request = GraphQLRequest.subscription(to: modelSchema, + request = GraphQLRequest.subscription(to: modelSchema, + where: predicate, subscriptionType: subscriptionType, claims: claims, authType: authType) @@ -213,12 +218,14 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { let oidcAuthProvider = hasOIDCAuthProviderAvailable(api: api), let tokenString = try? await oidcAuthProvider.getLatestAuthToken(), case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) { - request = GraphQLRequest.subscription(to: modelSchema, + request = GraphQLRequest.subscription(to: modelSchema, + where: predicate, subscriptionType: subscriptionType, claims: claims, authType: authType) } else { request = GraphQLRequest.subscription(to: modelSchema, + where: predicate, subscriptionType: subscriptionType, authType: authType) } @@ -296,6 +303,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { // MARK: - IncomingAsyncSubscriptionEventPublisher + API request factory extension IncomingAsyncSubscriptionEventPublisher { static func apiRequestFactoryFor(for modelSchema: ModelSchema, + where predicate: QueryPredicate?, subscriptionType: GraphQLSubscriptionType, api: APICategoryGraphQLBehaviorExtended, auth: AuthCategoryBehavior?, @@ -303,7 +311,8 @@ extension IncomingAsyncSubscriptionEventPublisher { authTypeProvider: AWSAuthorizationTypeIterator) -> RetryableGraphQLOperation.RequestFactory { var authTypes = authTypeProvider return { - return await IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(for: modelSchema, + return await IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(for: modelSchema, + where: predicate, subscriptionType: subscriptionType, api: api, auth: auth,