Skip to content

Commit

Permalink
Add tests
Browse files: browse the repository at this point in the history
  • Loading branch information
nstogner committed Nov 12, 2024
1 parent 0c4f845 commit 416b064
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 9 deletions.
2 changes: 1 addition & 1 deletion charts/kubeai/templates/configmap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ data:
modelServers:
{{- .Values.modelServers | toYaml | nindent 6 }}
modelLoading:
{{- .Values.modelLoaders | toYaml | nindent 6 }}
{{- .Values.modelLoading | toYaml | nindent 6 }}
modelRollouts:
{{- .Values.modelRollouts | toYaml | nindent 6 }}
modelServerPods:
Expand Down
6 changes: 3 additions & 3 deletions internal/modelproxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

log.Println("model:", pr.model)
log.Println("model:", pr.model, "adapter:", pr.adapter)

metricAttrs := metric.WithAttributeSet(attribute.NewSet(
metrics.AttrRequestModel.String(pr.model),
metrics.AttrRequestModel.String(pr.requestedModel),
metrics.AttrRequestType.String(metrics.AttrRequestTypeHTTP),
))
metrics.InferenceRequestsActive.Add(pr.r.Context(), 1, metricAttrs)
Expand All @@ -80,7 +80,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if !modelExists {
pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.model)
pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.requestedModel)
return
}

Expand Down
41 changes: 41 additions & 0 deletions internal/modelproxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,19 @@ func TestHandler(t *testing.T) {
model1 = "model1"
model2 = "model2"

model3 = "model3"
adapter3 = "adapter3"

maxRetries = 3
)
models := map[string]testMockModel{
model1: {},
model2: {},
model3: {
adapters: map[string]bool{
adapter3: true,
},
},
}

type metricsTestSpec struct {
Expand Down Expand Up @@ -69,6 +77,27 @@ func TestHandler(t *testing.T) {
},
expBackendRequestCount: 1,
},
"happy 200 model+adapter in body": {
reqBody: fmt.Sprintf(`{"model":%q}`, model3+"/"+adapter3),
backendCode: http.StatusOK,
backendBody: `{"result":"ok"}`,
expCode: http.StatusOK,
expBody: `{"result":"ok"}`,
expMetrics: &metricsTestSpec{
expModel: model3 + "/" + adapter3,
},
expBackendRequestCount: 1,
},
"404 model+adapter in body but missing adapter": {
reqBody: fmt.Sprintf(`{"model":%q}`, model1+"/no-such-adapter"),
expCode: http.StatusNotFound,
expBody: `{"error":"model not found: model1/no-such-adapter"}` + "\n",
},
"404 model+adapter in header but missing adapter": {
reqHeaders: map[string]string{"X-Model": model1 + "/no-such-adapter"},
expCode: http.StatusNotFound,
expBody: `{"error":"model not found: model1/no-such-adapter"}` + "\n",
},
"happy 200 model in header": {
reqBody: "{}",
reqHeaders: map[string]string{"X-Model": model1},
Expand All @@ -81,6 +110,18 @@ func TestHandler(t *testing.T) {
},
expBackendRequestCount: 1,
},
"happy 200 model+adapter in header": {
reqBody: "{}",
reqHeaders: map[string]string{"X-Model": model3 + "/" + adapter3},
backendCode: http.StatusOK,
backendBody: `{"result":"ok"}`,
expCode: http.StatusOK,
expBody: `{"result":"ok"}`,
expMetrics: &metricsTestSpec{
expModel: model3 + "/" + adapter3,
},
expBackendRequestCount: 1,
},
"happy 200 only model in form data": {
reqHeaders: map[string]string{"Content-Type": "multipart/form-data; boundary=12345"},
reqBody: fmt.Sprintf(
Expand Down
14 changes: 9 additions & 5 deletions internal/modelproxy/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ type proxyRequest struct {

selectors []string

id string
status int
model string
adapter string
attempt int
id string
status int
requestedModel string
model string
adapter string
attempt int
}

func newProxyRequest(r *http.Request) *proxyRequest {
Expand All @@ -52,6 +53,7 @@ func (pr *proxyRequest) parse() error {
// Try to get the model from the header first
if headerModel := pr.r.Header.Get("X-Model"); headerModel != "" {
pr.model, pr.adapter = apiutils.SplitModelAdapter(headerModel)
pr.requestedModel = headerModel
// Save the body content (required to support retries of the proxy request)
body, err := io.ReadAll(pr.r.Body)
if err != nil {
Expand Down Expand Up @@ -110,6 +112,7 @@ func (pr *proxyRequest) parse() error {
return fmt.Errorf("reading multipart form value: %w", err)
}
pr.model, pr.adapter = apiutils.SplitModelAdapter(string(value))
pr.requestedModel = string(value)
// WORKAROUND ALERT:
// Omit the "model" field from the proxy request to avoid FasterWhisper validation issues:
// See https://github.com/fedirz/faster-whisper-server/issues/71
Expand Down Expand Up @@ -149,6 +152,7 @@ func (pr *proxyRequest) parse() error {
return fmt.Errorf("unmarshal json: %w", err)
}
pr.model, pr.adapter = apiutils.SplitModelAdapter(payload.Model)
pr.requestedModel = payload.Model

if pr.model == "" {
return fmt.Errorf("no model specified")
Expand Down

0 comments on commit 416b064

Please sign in to comment.