Skip to content

Commit

Permalink
Fix method signature code gen to adhere to spec (#99)
Browse files Browse the repository at this point in the history
* fix method signature code gen to adhere to spec

* cleanup signature variant parsing impl
  • Loading branch information
marcoferrer authored Dec 27, 2019
1 parent 3a75d61 commit e223ead
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ class GrpcStubBuilder(val context: GeneratorContext){
MethodDescriptor.MethodType.UNARY -> {
addFunction(buildUnaryMethod(method))
addFunction(buildUnaryLambdaOverload(method))
buildUnaryMethodSigOverload(method)?.let { addFunction(it) }
addFunctions(buildUnaryMethodSigOverload(method))
}

MethodDescriptor.MethodType.SERVER_STREAMING -> {
addFunction(buildServerStreamingMethod(method))
addFunction(buildServerStreamingLambdaOverload(method))
buildServerStreamingMethodSigOverload(method)?.let { addFunction(it) }
addFunctions(buildServerStreamingMethodSigOverload(method))
}

MethodDescriptor.MethodType.CLIENT_STREAMING ->
Expand Down Expand Up @@ -190,27 +190,29 @@ class GrpcStubBuilder(val context: GeneratorContext){

// Method signature overloads

private fun buildUnaryMethodSigOverload(protoMethod: ProtoMethod): FunSpec? = with(protoMethod){
if(methodSignatureFields.isEmpty())
null else FunSpec.builder(functionName)
.addKdoc(attachedComments)
.addModifiers(KModifier.SUSPEND)
.returns(responseClassName)
.addMethodSignatureParameter(methodSignatureFields,context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("return %N(request)",functionName)
.build()
private fun buildUnaryMethodSigOverload(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.addModifiers(KModifier.SUSPEND)
.returns(responseClassName)
.addMethodSignatureParameter(variant,context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("return %N(request)",functionName)
.build()
}
}

private fun buildServerStreamingMethodSigOverload(protoMethod: ProtoMethod): FunSpec? = with(protoMethod){
if(methodSignatureFields.isEmpty())
null else FunSpec.builder(functionName)
.addKdoc(attachedComments)
.returns(CommonClassNames.receiveChannel.parameterizedBy(responseClassName))
.addMethodSignatureParameter(methodSignatureFields,context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("return %N(request)",functionName)
.build()
private fun buildServerStreamingMethodSigOverload(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.returns(CommonClassNames.receiveChannel.parameterizedBy(responseClassName))
.addMethodSignatureParameter(variant,context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("return %N(request)",functionName)
.build()
}
}

// Stub companion object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ServerStreamingStubExtsBuilder(val context: GeneratorContext){
val funSpecs = mutableListOf<FunSpec>()

// Add method signature exts
if(method.methodSignatureFields.isNotEmpty()){
if(method.methodSignatureVariants.isNotEmpty()){
funSpecs += buildAsyncMethodSigExt(method)
funSpecs += buildBlockingMethodSigExt(method)
}
Expand All @@ -57,7 +57,7 @@ class ServerStreamingStubExtsBuilder(val context: GeneratorContext){
}

private fun addCoroutineStubExts(funSpecs: MutableList<FunSpec>, method: ProtoMethod){
if(method.methodSignatureFields.isNotEmpty()) {
if(method.methodSignatureVariants.isNotEmpty()) {
funSpecs += buildCoroutineMethodSigExt(method)
}

Expand All @@ -84,38 +84,44 @@ class ServerStreamingStubExtsBuilder(val context: GeneratorContext){

// Method Signature Extensions

private fun buildAsyncMethodSigExt(protoMethod: ProtoMethod): FunSpec = with(protoMethod){
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.asyncStubClassName)
.returns(UNIT)
.addMethodSignatureParameter(methodSignatureFields, context.schema)
.addResponseObserverParameter(responseClassName)
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("%N(request, responseObserver)",functionName)
.build()
private fun buildAsyncMethodSigExt(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.asyncStubClassName)
.returns(UNIT)
.addMethodSignatureParameter(variant, context.schema)
.addResponseObserverParameter(responseClassName)
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("%N(request, responseObserver)", functionName)
.build()
}
}

private fun buildCoroutineMethodSigExt(protoMethod: ProtoMethod): FunSpec = with(protoMethod){
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.asyncStubClassName)
.addMethodSignatureParameter(methodSignatureFields, context.schema)
.returns(CommonClassNames.receiveChannel.parameterizedBy(responseClassName))
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("return %N(request)",functionName)
.build()
private fun buildCoroutineMethodSigExt(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.asyncStubClassName)
.addMethodSignatureParameter(variant, context.schema)
.returns(CommonClassNames.receiveChannel.parameterizedBy(responseClassName))
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("return %N(request)", functionName)
.build()
}
}

private fun buildBlockingMethodSigExt(protoMethod: ProtoMethod): FunSpec = with(protoMethod){
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.blockingStubClassName)
.addMethodSignatureParameter(methodSignatureFields, context.schema)
.returns(Iterator::class.asClassName().parameterizedBy(responseClassName))
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("return %N(request)",functionName)
.build()
private fun buildBlockingMethodSigExt(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.blockingStubClassName)
.addMethodSignatureParameter(variant, context.schema)
.returns(Iterator::class.asClassName().parameterizedBy(responseClassName))
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("return %N(request)", functionName)
.build()
}
}

// Lambda Builder Extensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class UnaryStubExtsBuilder(val context: GeneratorContext){
val funSpecs = mutableListOf<FunSpec>()

// Add method signature exts
if(method.methodSignatureFields.isNotEmpty()){
if(method.methodSignatureVariants.isNotEmpty()){
funSpecs += buildAsyncMethodSigExt(method)
funSpecs += buildFutureMethodSigExt(method)
funSpecs += buildBlockingMethodSigExt(method)
Expand All @@ -59,7 +59,7 @@ class UnaryStubExtsBuilder(val context: GeneratorContext){
}

private fun addCoroutineStubExts(funSpecs: MutableList<FunSpec>, method: ProtoMethod){
if(method.methodSignatureFields.isNotEmpty()) {
if(method.methodSignatureVariants.isNotEmpty()) {
funSpecs += buildCoroutineMethodSigExt(method)
}

Expand Down Expand Up @@ -87,50 +87,58 @@ class UnaryStubExtsBuilder(val context: GeneratorContext){

// Method Signature Extensions

private fun buildAsyncMethodSigExt(protoMethod: ProtoMethod): FunSpec = with(protoMethod){
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.asyncStubClassName)
.addMethodSignatureParameter(methodSignatureFields, context.schema)
.addResponseObserverParameter(responseClassName)
.returns(UNIT)
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("%N(request, responseObserver)",functionName)
.build()
private fun buildAsyncMethodSigExt(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.asyncStubClassName)
.addMethodSignatureParameter(variant, context.schema)
.addResponseObserverParameter(responseClassName)
.returns(UNIT)
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("%N(request, responseObserver)",functionName)
.build()
}
}

private fun buildCoroutineMethodSigExt(protoMethod: ProtoMethod): FunSpec = with(protoMethod){
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.addModifiers(KModifier.SUSPEND)
.receiver(protoService.asyncStubClassName)
.returns(responseClassName)
.addMethodSignatureParameter(methodSignatureFields, context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("return %N(request)",functionName)
.build()
private fun buildCoroutineMethodSigExt(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.addModifiers(KModifier.SUSPEND)
.receiver(protoService.asyncStubClassName)
.returns(responseClassName)
.addMethodSignatureParameter(variant, context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("return %N(request)",functionName)
.build()
}
}

private fun buildFutureMethodSigExt(protoMethod: ProtoMethod): FunSpec = with(protoMethod){
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.futureStubClassName)
.addMethodSignatureParameter(methodSignatureFields, context.schema)
.returns(CommonClassNames.listenableFuture.parameterizedBy(responseClassName))
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("return %N(request)",functionName)
.build()
private fun buildFutureMethodSigExt(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.futureStubClassName)
.addMethodSignatureParameter(variant, context.schema)
.returns(CommonClassNames.listenableFuture.parameterizedBy(responseClassName))
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("return %N(request)",functionName)
.build()
}
}

private fun buildBlockingMethodSigExt(protoMethod: ProtoMethod): FunSpec = with(protoMethod){
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.blockingStubClassName)
.returns(responseClassName)
.addMethodSignatureParameter(methodSignatureFields, context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(methodSignatureFields))
.addStatement("return %N(request)",functionName)
.build()
private fun buildBlockingMethodSigExt(protoMethod: ProtoMethod): List<FunSpec> = with(protoMethod){
methodSignatureVariants.map { variant ->
FunSpec.builder(functionName)
.addKdoc(attachedComments)
.receiver(protoService.blockingStubClassName)
.returns(responseClassName)
.addMethodSignatureParameter(variant, context.schema)
.addCode(requestClassName.requestValueMethodSigCodeBlock(variant))
.addStatement("return %N(request)", functionName)
.build()
}
}

// Lambda Builder Extensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,25 @@ class ProtoMethod(
else -> throw IllegalStateException("Unknown method type")
}

val methodSignatureFields: List<DescriptorProtos.FieldDescriptorProto> = getMethodSignatureFields()

val methodSignatureVariants: List<List<DescriptorProtos.FieldDescriptorProto>> = getMethodSignatureVariants()
}

private fun ProtoMethod.getMethodSignatureFields(): List<DescriptorProtos.FieldDescriptorProto> =
descriptorProto.options
private fun ProtoMethod.getMethodSignatureVariants(): List<List<DescriptorProtos.FieldDescriptorProto>> {

val methodSignaturesOptions = descriptorProto.options
.runCatching { getExtension(ClientProto.methodSignature) }
.getOrNull()
?.takeUnless { it.isEmpty() }
?.let { signatureFields ->
requestType.descriptorProto.fieldList
.filter { it.name in signatureFields }
}
.orEmpty()

val signatureVariants = methodSignaturesOptions?.map { variant ->
val variantFields = variant.split(",")

requestType.descriptorProto.fieldList
.filter { it.name in variantFields }
}

return signatureVariants.orEmpty()
}

/**
* Comment parsing is based on the following implementation
Expand Down
2 changes: 2 additions & 0 deletions test-api/src/main/proto/message/test_messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@ service __MalformedService__ {
rpc say_hello (L1Message1) returns (L1Message2){
option (google.api.method_signature) = "field";
option (google.api.method_signature) = "nested_message";
option (google.api.method_signature) = "field,nested_message";
};


rpc sayHelloServerStreaming (L1Message1) returns (stream L1Message2){
option (google.api.method_signature) = "field";
option (google.api.method_signature) = "nested_message";
option (google.api.method_signature) = "field,nested_message";
};

rpc sayHelloStreaming (stream L1Message1) returns (stream L1Message2);
Expand Down

0 comments on commit e223ead

Please sign in to comment.