Skip to content

Commit

Permalink
Merge pull request #13 from kerthcet/feat/api-defination
Browse files Browse the repository at this point in the history
Add webhook to Playground
  • Loading branch information
InftyAI-Agent authored Jul 15, 2024
2 parents 94a85fe + 01091cd commit dffa415
Show file tree
Hide file tree
Showing 19 changed files with 300 additions and 30 deletions.
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# ENVTEST_K8S_VERSION refers to the version of kubebuilder assets to be downloaded by envtest binary.
ENVTEST_K8S_VERSION = 1.28.3

Expand Down
5 changes: 3 additions & 2 deletions api/inference/v1alpha1/playground_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ type PlaygroundSpec struct {
// ModelClaim and multiModelsClaims are exclusive configured.
// Note: properties (nodeSelectors, resources, e.g.) of the model flavors
// will be applied to the workload if not exist.
ModelClaim api.ModelClaim `json:"modelClaim"`
// +optional
ModelClaim *api.ModelClaim `json:"modelClaim,omitempty"`
// MultiModelsClaims represents multiple modelClaim, which is useful when different
// sub-workload has different accelerator requirements, like the state-of-the-art
// technology called splitwise, the workload template is shared by both.
// ModelClaim and multiModelsClaims are exclusive configured.
// +kubebuilder:validation:MinItems=1
// +optional
MultiModelsClaims []api.MultiModelsClaim `json:"multiModelsClaims,omitempty"`
// BackendConfig represents the inference backend configuration
// under the hood, e.g. vLLM, which is the default backend.
Expand Down
2 changes: 1 addition & 1 deletion api/inference/v1alpha1/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type BackendConfig struct {
// +kubebuilder:validation:Enum={vllm}
// +kubebuilder:default=vllm
// +optional
Name *BackendName `json:"name"`
Name *BackendName `json:"name,omitempty"`
// Version represents the backend version if you want a different one
// from the default version.
// +optional
Expand Down
6 changes: 5 additions & 1 deletion api/inference/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions api/v1alpha1/model_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ type ModelName string
// ModelClaim represents the references to one model.
// It's a simple config for most of the cases compared to multiModelsClaim.
type ModelClaim struct {
// ModelNames represents a list of models, there maybe multiple models here
// ModelName represents a list of models, there maybe multiple models here
// to support state-of-the-art technologies like speculative decoding.
ModelNames ModelName `json:"modelNames,omitempty"`
ModelName ModelName `json:"modelName,omitempty"`
// InferenceFlavors represents a list of flavors with fungibility supported
// to serve the model.
// - If not set and multiple models claimed, apply with the 0-index model by default.
Expand Down Expand Up @@ -123,7 +123,7 @@ type MultiModelsClaim struct {
// This is mostly designed for state-of-the-art technology called splitwise, the prefill
// and decode phase will be separated and requires different accelerators.
// The sum of the rates should be divisible by replicas.
Rate *int `json:"rate,omitempty"`
Rate *int32 `json:"rate,omitempty"`
}

// ModelSpec defines the desired state of Model
Expand Down
2 changes: 1 addition & 1 deletion api/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 9 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ func setupControllers(mgr ctrl.Manager, certsReady chan struct{}) {
<-certsReady
setupLog.Info("certs ready")

if err := (&inferencecontroller.ServiceReconciler{
if err := (&controller.ModelReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
}).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "Service")
setupLog.Error(err, "unable to create controller", "controller", "Model")
os.Exit(1)
}
if err := (&inferencecontroller.PlaygroundReconciler{
Expand All @@ -145,17 +145,22 @@ func setupControllers(mgr ctrl.Manager, certsReady chan struct{}) {
setupLog.Error(err, "unable to create controller", "controller", "Playground")
os.Exit(1)
}
if err := (&controller.ModelReconciler{
if err := (&inferencecontroller.ServiceReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
}).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "Model")
setupLog.Error(err, "unable to create controller", "controller", "Service")
os.Exit(1)
}

if os.Getenv("ENABLE_WEBHOOKS") != "false" {
if err := webhook.SetupModelWebhook(mgr); err != nil {
setupLog.Error(err, "unable to create webhook", "webhook", "Model")
os.Exit(1)
}
if err := webhook.SetupPlaygroundWebhook(mgr); err != nil {
setupLog.Error(err, "unable to create webhook", "webhook", "Playground")
os.Exit(1)
}
}
}
8 changes: 3 additions & 5 deletions config/crd/bases/inference.llmaz.io_playgrounds.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ spec:
items:
type: string
type: array
modelNames:
modelName:
description: |-
ModelNames represents a list of models, there maybe multiple models here
ModelName represents a list of models, there maybe multiple models here
to support state-of-the-art technologies like speculative decoding.
type: string
type: object
Expand Down Expand Up @@ -289,17 +289,15 @@ spec:
This is mostly designed for state-of-the-art technology called splitwise, the prefill
and decode phase will be separated and requires different accelerators.
The sum of the rates should be divisible by replicas.
format: int32
type: integer
type: object
minItems: 1
type: array
replicas:
default: 1
description: Replicas represents the replica number of inference workloads.
format: int32
type: integer
required:
- modelClaim
type: object
status:
description: PlaygroundStatus defines the observed state of Playground
Expand Down
1 change: 1 addition & 0 deletions config/crd/bases/inference.llmaz.io_services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ spec:
This is mostly designed for state-of-the-art technology called splitwise, the prefill
and decode phase will be separated and requires different accelerators.
The sum of the rates should be divisible by replicas.
format: int32
type: integer
type: object
minItems: 1
Expand Down
4 changes: 2 additions & 2 deletions config/samples/_v1alpha1_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ metadata:
app.kubernetes.io/part-of: llmaz
app.kubernetes.io/managed-by: kustomize
app.kubernetes.io/created-by: llmaz
name: llama2-7b
name: llama3-8b
spec:
familyName: "llama2"
familyName: "llama3"
dataSource:
url: https://<url>
inferenceFlavors:
Expand Down
4 changes: 2 additions & 2 deletions config/samples/inference_v1alpha1_playground.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ metadata:
app.kubernetes.io/part-of: llmaz
app.kubernetes.io/managed-by: kustomize
app.kubernetes.io/created-by: llmaz
name: playground-sample
name: playground-llama3-8b
spec:
replicas: 1
modelClaim:
modelName: "llama2-7b"
modelName: "llama3-8b"
4 changes: 2 additions & 2 deletions config/samples/inference_v1alpha1_service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ metadata:
app.kubernetes.io/part-of: llmaz
app.kubernetes.io/managed-by: kustomize
app.kubernetes.io/created-by: llmaz
name: service-sample
name: service-llama3-8b
spec:
multiModelsClaims:
- modelNames:
- "llama2-7b"
- "llama3-8b"
workloadTemplate:
replicas: 1
leaderWorkerTemplate:
Expand Down
40 changes: 40 additions & 0 deletions config/webhook/manifests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,26 @@ webhooks:
resources:
- models
sideEffects: None
- admissionReviewVersions:
- v1
clientConfig:
service:
name: webhook-service
namespace: system
path: /mutate-inference-llmaz-io-v1alpha1-playground
failurePolicy: Fail
name: mplayground.kb.io
rules:
- apiGroups:
- inference.llmaz.io
apiVersions:
- v1alpha1
operations:
- CREATE
- UPDATE
resources:
- playgrounds
sideEffects: None
---
apiVersion: admissionregistration.k8s.io/v1
kind: ValidatingWebhookConfiguration
Expand All @@ -50,3 +70,23 @@ webhooks:
resources:
- models
sideEffects: None
- admissionReviewVersions:
- v1
clientConfig:
service:
name: webhook-service
namespace: system
path: /validate-inference-llmaz-io-v1alpha1-playground
failurePolicy: Fail
name: vplayground.kb.io
rules:
- apiGroups:
- inference.llmaz.io
apiVersions:
- v1alpha1
operations:
- CREATE
- UPDATE
resources:
- playgrounds
sideEffects: None
81 changes: 81 additions & 0 deletions internal/webhook/playground_webhook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
Copyright 2024.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package webhook

import (
"context"

"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

inferenceapi "inftyai.com/llmaz/api/inference/v1alpha1"
)

type PlaygroundWebhook struct{}

// SetupPlaygroundWebhook will setup the manager to manage the webhooks
func SetupPlaygroundWebhook(mgr ctrl.Manager) error {
return ctrl.NewWebhookManagedBy(mgr).
For(&inferenceapi.Playground{}).
WithDefaulter(&PlaygroundWebhook{}).
WithValidator(&PlaygroundWebhook{}).
Complete()
}

//+kubebuilder:webhook:path=/mutate-inference-llmaz-io-v1alpha1-playground,mutating=true,failurePolicy=fail,sideEffects=None,groups=inference.llmaz.io,resources=playgrounds,verbs=create;update,versions=v1alpha1,name=mplayground.kb.io,admissionReviewVersions=v1

var _ webhook.CustomDefaulter = &PlaygroundWebhook{}

// Default implements webhook.Defaulter so a webhook will be registered for the type
func (w *PlaygroundWebhook) Default(ctx context.Context, obj runtime.Object) error {
return nil
}

//+kubebuilder:webhook:path=/validate-inference-llmaz-io-v1alpha1-playground,mutating=false,failurePolicy=fail,sideEffects=None,groups=inference.llmaz.io,resources=playgrounds,verbs=create;update,versions=v1alpha1,name=vplayground.kb.io,admissionReviewVersions=v1

var _ webhook.CustomValidator = &PlaygroundWebhook{}

// ValidateCreate implements webhook.Validator so a webhook will be registered for the type
func (w *PlaygroundWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
warnings, allErrs := w.generateValidate(obj)
return warnings, allErrs.ToAggregate()
}

// ValidateUpdate implements webhook.Validator so a webhook will be registered for the type
func (w *PlaygroundWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
warnings, allErrs := w.generateValidate(newObj)
return warnings, allErrs.ToAggregate()
}

// ValidateDelete implements webhook.Validator so a webhook will be registered for the type
func (w *PlaygroundWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
return nil, nil
}

func (w *PlaygroundWebhook) generateValidate(obj runtime.Object) (admission.Warnings, field.ErrorList) {
playground := obj.(*inferenceapi.Playground)
specPath := field.NewPath("spec")

var allErrs field.ErrorList
if playground.Spec.ModelClaim == nil && len(playground.Spec.MultiModelsClaims) == 0 {
allErrs = append(allErrs, field.Forbidden(specPath, "modelClaim and multiModelsClaims couldn't be both empty"))
}
return nil, allErrs
}
9 changes: 5 additions & 4 deletions test/integration/webhook/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ var _ = ginkgo.Describe("model default and validation", func() {
},
ginkgo.Entry("apply model family name", &testDefaultingCase{
model: func() *api.Model {
return wrapper.MakeModel("llama2-7b").DataSourceWithModel("llama2", "Huggingface").FamilyName("llama2").Obj()
return wrapper.MakeModel("llama3-8b").DataSourceWithModel("meta-llama/meta-llama-3-8b", "Huggingface").FamilyName("llama3").Obj()
},
wantModel: func() *api.Model {
return wrapper.MakeModel("llama2-7b").FamilyName("llama2").DataSourceWithModel("llama2", "Huggingface").Label(api.ModelFamilyNameLabelKey, "llama2").Obj()
return wrapper.MakeModel("llama3-8b").DataSourceWithModel("meta-llama/meta-llama-3-8b", "Huggingface").FamilyName("llama3").Label(api.ModelFamilyNameLabelKey, "llama3").Obj()
},
}),
)
Expand All @@ -64,6 +64,7 @@ var _ = ginkgo.Describe("model default and validation", func() {
model func() *api.Model
failed bool
}
// TODO: add more testCases to cover update.
ginkgo.DescribeTable("test validating",
func(tc *testValidatingCase) {
if tc.failed {
Expand All @@ -74,13 +75,13 @@ var _ = ginkgo.Describe("model default and validation", func() {
},
ginkgo.Entry("normal model creation", &testValidatingCase{
model: func() *api.Model {
return wrapper.MakeModel("llama2-7b").FamilyName("llama2").DataSourceWithModel("llama2", "Huggingface").Obj()
return wrapper.MakeModel("llama3-8b").FamilyName("llama3").DataSourceWithModel("meta-llama/meta-llama-3-8b", "Huggingface").Obj()
},
failed: false,
}),
ginkgo.Entry("no data source configured", &testValidatingCase{
model: func() *api.Model {
return wrapper.MakeModel("llama2-7b").FamilyName("llama2").Obj()
return wrapper.MakeModel("llama3-8b").FamilyName("llama3").Obj()
},
failed: true,
}),
Expand Down
Loading

0 comments on commit dffa415

Please sign in to comment.