From 9409b4414c37a173f8467bea52f383cde5a84393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Franke?= Date: Tue, 13 Aug 2024 16:57:23 +0200 Subject: [PATCH] Convert contract negotiation handlers to return err (#47) This converts the contract negotiation handlers to return HTTPReturnErrors. In the future we might want to make the statemachine tailor its own errors, but this makes future improvements possible. --- dsp/contract_handlers.go | 221 +++++++++++++++++++++++---------------- dsp/routing.go | 32 +++--- 2 files changed, 149 insertions(+), 104 deletions(-) diff --git a/dsp/contract_handlers.go b/dsp/contract_handlers.go index 87c8b4e..3ad6ee1 100644 --- a/dsp/contract_handlers.go +++ b/dsp/contract_handlers.go @@ -28,44 +28,92 @@ import ( "github.com/google/uuid" ) -func (dh *dspHandlers) providerContractStateHandler(w http.ResponseWriter, req *http.Request) { +type ContractError struct { + status int + contract *statemachine.Contract + dspCode string + reason string + err string +} + +func (ce ContractError) Error() string { return ce.err } +func (ce ContractError) StatusCode() int { return ce.status } +func (ce ContractError) ErrorType() string { return "dspace:ContractNegotiationError" } +func (ce ContractError) DSPCode() string { return ce.dspCode } + +func (ce ContractError) Description() []shared.Multilanguage { + return []shared.Multilanguage{{Value: ce.reason, Language: "en"}} +} + +func (ce ContractError) Reason() []shared.Multilanguage { + return []shared.Multilanguage{{Value: ce.reason, Language: "en"}} +} + +func (ce ContractError) ProviderPID() string { + if ce.contract == nil { + return "" + } + return ce.contract.GetProviderPID().URN() +} + +func (ce ContractError) ConsumerPID() string { + if ce.contract == nil { + return "" + } + return ce.contract.GetConsumerPID().URN() +} + +func contractError( + err string, statusCode int, dspCode string, reason string, contract *statemachine.Contract, +) ContractError { + return ContractError{ + status: statusCode, + contract: contract, + dspCode: dspCode, + reason: reason, + err: err, + } +} + +func (dh *dspHandlers) providerContractStateHandler(w http.ResponseWriter, req *http.Request) error { logger := logging.Extract(req.Context()) providerPID, err := uuid.Parse(req.PathValue("providerPID")) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid provider PID") - return + return contractError("invalid provider ID", http.StatusBadRequest, "400", "Invalid provider PID", nil) } contract, err := dh.store.GetProviderContract(req.Context(), providerPID) if err != nil { - returnError(w, http.StatusNotFound, "no contract found") - return + return contractError(err.Error(), http.StatusNotFound, "404", "Contract not found", nil) } + if err := shared.EncodeValid(w, req, http.StatusOK, contract.GetContractNegotiation()); err != nil { - logger.Error("couldn't serve contract state: %w", "err", err) + logger.Error("couldn't serve contract state", "err", err) } + return nil } -func (dh *dspHandlers) providerContractRequestHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) providerContractRequestHandler(w http.ResponseWriter, req *http.Request) error { ctx, logger := logging.InjectLabels(req.Context(), "handler", "providerContractRequestHandler") req = req.WithContext(ctx) contractReq, err := shared.DecodeValid[shared.ContractRequestMessage](req) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError(fmt.Sprintf("invalid request message: %s", err.Error()), + http.StatusBadRequest, "400", "Invalid request", nil) } logger.Debug("Got contract request", "req", contractReq) consumerPID, err := uuid.Parse(contractReq.ConsumerPID) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request: ConsumerPID is not a UUID") - return + return contractError(fmt.Sprintf("Invalid consumer ID %s: %s", contractReq.ConsumerPID, err.Error()), + http.StatusBadRequest, "400", "Invalid request: ConsumerPID is not a UUID", nil) } + // TODO: Check if callback URL is reachable? cbURL, err := url.Parse(contractReq.CallbackAddress) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request: Non-valid callback URL.") - return + return contractError(fmt.Sprintf("Invalid callback URL %s: %s", contractReq.CallbackAddress, err.Error()), + http.StatusBadRequest, "400", "Invalid request: Non-valid callback URL.", nil) } // TODO: Maybe make a function in the statemachine that parses the contract request message. @@ -81,38 +129,40 @@ func (dh *dspHandlers) providerContractRequestHandler(w http.ResponseWriter, req req = req.WithContext(ctx) if err != nil { - returnError(w, http.StatusInternalServerError, "Couldn't create contract") - return + return contractError(fmt.Sprintf("couldn't create contract: %s", err.Error()), + http.StatusInternalServerError, "500", "Failed to create contract", nil) } ctx, nextState, err := pState.Recv(req.Context(), contractReq) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError( + fmt.Sprintf("couldn't receive message: %s", err.Error()), + http.StatusBadRequest, "400", "Invalid request", pState.GetContract(), + ) } req = req.WithContext(ctx) apply, err := nextState.Send(req.Context()) if err != nil { - returnError(w, http.StatusInternalServerError, "Not able to progress state") - return + return contractError(fmt.Sprintf("couldn't progress to next state: %s", err.Error()), + http.StatusInternalServerError, "500", "Not able to progress state", nextState.GetContract()) } if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetContractNegotiation()); err != nil { logger.Error("Couldn't serve response", "err", err) - returnError(w, http.StatusInternalServerError, "Failed to serve response") } go apply() + return nil } func progressContractState[T any]( dh *dspHandlers, w http.ResponseWriter, req *http.Request, role statemachine.DataspaceRole, rawPID string, -) { +) error { logger := logging.Extract(req.Context()) pid, err := uuid.Parse(rawPID) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid PID") - return + return contractError(fmt.Sprintf("Invalid PID %s: %s", rawPID, err.Error()), + http.StatusBadRequest, "400", "Invalid request: PID is not a UUID", nil) } var contract *statemachine.Contract @@ -125,14 +175,14 @@ func progressContractState[T any]( panic(fmt.Sprintf("unexpected statemachine.ContractRole: %#v", role)) } if err != nil { - returnError(w, http.StatusNotFound, "Contract not found") - return + return contractError(fmt.Sprintf("%d contract %s not found: %s", role, pid, err), + http.StatusNotFound, "404", "Contract not found", nil) } msg, err := shared.DecodeValid[T](req) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError(fmt.Sprintf("could not decode message: %s", err), + http.StatusBadRequest, "400", "Invalid request", contract) } logger.Debug("Got contract message", "req", msg) @@ -143,96 +193,94 @@ func progressContractState[T any]( ctx, nextState, err := pState.Recv(req.Context(), msg) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError(fmt.Sprintf("invalid request: %s", err), + http.StatusBadRequest, "400", "Invalid request", pState.GetContract()) } req = req.WithContext(ctx) apply, err := nextState.Send(req.Context()) if err != nil { - returnError(w, http.StatusInternalServerError, "Not able to progress state") - return + return contractError(fmt.Sprintf("couldn't progress to next state: %s", err.Error()), + http.StatusInternalServerError, "500", "Not able to progress state", nextState.GetContract()) } if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetContractNegotiation()); err != nil { logger.Error("Couldn't serve response", "err", err) - returnError(w, http.StatusInternalServerError, "Failed to serve response") } go apply() + return nil } -func (dh *dspHandlers) providerContractSpecificRequestHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) providerContractSpecificRequestHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractSpecificRequestHandler") req = req.WithContext(ctx) - progressContractState[shared.ContractRequestMessage]( + return progressContractState[shared.ContractRequestMessage]( dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), ) } -func (dh *dspHandlers) providerContractEventHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) providerContractEventHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractEventHandler") req = req.WithContext(ctx) - progressContractState[shared.ContractNegotiationEventMessage]( + return progressContractState[shared.ContractNegotiationEventMessage]( dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), ) } -func (dh *dspHandlers) providerContractVerificationHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) providerContractVerificationHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractVerificationHandler") req = req.WithContext(ctx) - progressContractState[shared.ContractAgreementVerificationMessage]( + return progressContractState[shared.ContractAgreementVerificationMessage]( dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), ) } // TODO: Implement terminations. -func (dh *dspHandlers) providerContractTerminationHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) providerContractTerminationHandler(w http.ResponseWriter, req *http.Request) error { ctx, logger := logging.InjectLabels(req.Context(), "handler", "providerContractTerminationHandler") req = req.WithContext(ctx) providerPID := req.PathValue("providerPID") if providerPID == "" { - returnError(w, http.StatusBadRequest, "Missing provider PID") - return + return contractError("missing provider PID", http.StatusBadRequest, "400", "Missing provider PID", nil) } reqBody, err := io.ReadAll(req.Body) if err != nil { - returnError(w, http.StatusBadRequest, "Could not read body") - return + return contractError("could not read body", http.StatusInternalServerError, "500", "could not read body", nil) } verification, err := shared.UnmarshalAndValidate( req.Context(), reqBody, shared.ContractNegotiationTerminationMessage{}) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError("invalid request", http.StatusBadGateway, "400", "invalid request", nil) } logger.Debug("Got contract termination", "termination", verification) // If all goes well, we just return a 200 w.WriteHeader(http.StatusOK) + return nil } -func (dh *dspHandlers) consumerContractOfferHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) consumerContractOfferHandler(w http.ResponseWriter, req *http.Request) error { ctx, logger := logging.InjectLabels(req.Context(), "handler", "consumerContractOfferHandler") req = req.WithContext(ctx) contractOffer, err := shared.DecodeValid[shared.ContractOfferMessage](req) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError(fmt.Sprintf("invalid request message: %s", err.Error()), + http.StatusBadRequest, "400", "Invalid request", nil) } logger.Debug("Got contract offer", "offer", contractOffer) providerPID, err := uuid.Parse(contractOffer.ProviderPID) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request: ProviderPID is not a UUID") - return + return contractError(fmt.Sprintf("Invalid providerPID ID %s: %s", contractOffer.ProviderPID, err.Error()), + http.StatusBadRequest, "400", "Invalid request: ProviderPID is not a UUID", nil) } cbURL, err := url.Parse(contractOffer.CallbackAddress) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request: non-valid callback URL.") - return + return contractError(fmt.Sprintf("Invalid callback URL %s: %s", contractOffer.CallbackAddress, err.Error()), + http.StatusBadRequest, "400", "Invalid request: Non-valid callback URL.", nil) } selfURL, err := url.Parse(dh.selfURL.String()) @@ -251,90 +299,88 @@ func (dh *dspHandlers) consumerContractOfferHandler(w http.ResponseWriter, req * logger = logging.Extract(ctx) req = req.WithContext(ctx) if err != nil { - returnError(w, http.StatusInternalServerError, "Couldn't create contract") - return + return contractError(fmt.Sprintf("couldn't create contract: %s", err.Error()), + http.StatusInternalServerError, "500", "Failed to create contract", nil) } ctx, nextState, err := cState.Recv(req.Context(), contractOffer) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError(fmt.Sprintf("couldn't receive message: %s", err.Error()), + http.StatusBadRequest, "400", "Invalid request", cState.GetContract()) } req = req.WithContext(ctx) apply, err := nextState.Send(req.Context()) if err != nil { - returnError(w, http.StatusInternalServerError, "Not able to progress state") - return + return contractError(fmt.Sprintf("couldn't progress to next state: %s", err.Error()), + http.StatusInternalServerError, "500", "Not able to progress state", nextState.GetContract()) } if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetContractNegotiation()); err != nil { logger.Error("Couldn't serve response", "err", err) - returnError(w, http.StatusInternalServerError, "Failed to serve response") } go apply() + return nil } -func (dh *dspHandlers) consumerContractSpecificOfferHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) consumerContractSpecificOfferHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractSpecificOfferHandler") req = req.WithContext(ctx) - progressContractState[shared.ContractOfferMessage]( + return progressContractState[shared.ContractOfferMessage]( dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), ) } -func (dh *dspHandlers) consumerContractAgreementHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) consumerContractAgreementHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractAgreementHandler") req = req.WithContext(ctx) - progressContractState[shared.ContractAgreementMessage]( + return progressContractState[shared.ContractAgreementMessage]( dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), ) } -func (dh *dspHandlers) consumerContractEventHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) consumerContractEventHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractEventHandler") req = req.WithContext(ctx) - progressContractState[shared.ContractNegotiationEventMessage]( + return progressContractState[shared.ContractNegotiationEventMessage]( dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), ) } // TODO: Handle termination in the statemachine. -func (dh *dspHandlers) consumerContractTerminationHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) consumerContractTerminationHandler(w http.ResponseWriter, req *http.Request) error { ctx, logger := logging.InjectLabels(req.Context(), "handler", "consumerContractTerminationHandler") req = req.WithContext(ctx) consumerPID := req.PathValue("consumerPID") if consumerPID == "" { - returnError(w, http.StatusBadRequest, "Missing consumer PID") - return + return contractError("missing consumer PID", http.StatusBadRequest, "400", "Missing consumer PID", nil) } reqBody, err := io.ReadAll(req.Body) if err != nil { - returnError(w, http.StatusBadRequest, "Could not read body") - return + return contractError("could not read body", http.StatusInternalServerError, "500", "could not read body", nil) } // This should have the event FINALIZED termination, err := shared.UnmarshalAndValidate(req.Context(), reqBody, shared.ContractNegotiationTerminationMessage{}) if err != nil { - returnError(w, http.StatusBadRequest, "Invalid request") - return + return contractError("invalid request", http.StatusBadGateway, "400", "invalid request", nil) } logger.Debug("Got contract event", "termination", termination) // If all goes well, we just return a 200 w.WriteHeader(http.StatusOK) + + return nil } -func (dh *dspHandlers) triggerConsumerContractRequestHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) triggerConsumerContractRequestHandler(w http.ResponseWriter, req *http.Request) error { ctx, logger := logging.InjectLabels(req.Context(), "handler", "triggerConsumerContractRequestHandler") req = req.WithContext(ctx) datasetID, err := uuid.Parse(req.PathValue("datasetID")) if err != nil { - returnError(w, http.StatusBadRequest, "Dataset ID is not a UUID") - return + return fmt.Errorf("Dataset ID is not a UUID") } logger.Debug("Got trigger request to start contract negotiation") @@ -364,36 +410,32 @@ func (dh *dspHandlers) triggerConsumerContractRequestHandler(w http.ResponseWrit ) req = req.WithContext(ctx) if err != nil { - returnError(w, http.StatusInternalServerError, "Couldn't create contract") - return + return err } apply, err := cInit.Send(req.Context()) if err != nil { - returnError(w, http.StatusInternalServerError, err.Error()) - return + return err } if err := shared.EncodeValid(w, req, http.StatusOK, cInit.GetContractNegotiation()); err != nil { logger.Error("Couldn't serve response", "err", err) - returnError(w, http.StatusInternalServerError, "Failed to serve response") } go apply() + return nil } -func (dh *dspHandlers) triggerTransferRequestHandler(w http.ResponseWriter, req *http.Request) { +func (dh *dspHandlers) triggerTransferRequestHandler(w http.ResponseWriter, req *http.Request) error { ctx, logger := logging.InjectLabels(req.Context(), "handler", "triggerConsumerContractRequestHandler") req = req.WithContext(ctx) pid, err := uuid.Parse(req.PathValue("contractProviderPID")) if err != nil { - returnError(w, http.StatusBadRequest, "Not a PID") - return + return err } con, err := dh.store.GetProviderContract(ctx, pid) if err != nil { - returnError(w, http.StatusNotFound, "No contract found") - return + return err } logger.Debug("Got trigger request to start contract negotiation") selfURL, err := url.Parse(dh.selfURL.String()) @@ -419,18 +461,17 @@ func (dh *dspHandlers) triggerTransferRequestHandler(w http.ResponseWriter, req ) req = req.WithContext(ctx) if err != nil { - returnError(w, http.StatusInternalServerError, "Couldn't create transfer request") - return + return err } apply, err := cInit.Send(req.Context()) if err != nil { - returnError(w, http.StatusInternalServerError, err.Error()) - return + return err } if err := shared.EncodeValid(w, req, http.StatusOK, cInit.GetTransferProcess()); err != nil { logger.Error("Couldn't serve response", "err", err) - returnError(w, http.StatusInternalServerError, "Failed to serve response") } go apply() + + return nil } diff --git a/dsp/routing.go b/dsp/routing.go index 3cec35b..d369ddb 100644 --- a/dsp/routing.go +++ b/dsp/routing.go @@ -47,19 +47,23 @@ func GetDSPRoutes( mux.Handle("GET /catalog/datasets/{id}", WrapHandlerWithError(ch.datasetRequestHandler)) // Contract negotiation endpoints - mux.HandleFunc("GET /negotiations/{providerPID}", ch.providerContractStateHandler) - mux.HandleFunc("POST /negotiations/request", ch.providerContractRequestHandler) - mux.HandleFunc("POST /negotiations/{providerPID}/request", ch.providerContractSpecificRequestHandler) - mux.HandleFunc("POST /negotiations/{providerPID}/events", ch.providerContractEventHandler) - mux.HandleFunc("POST /negotiations/{providerPID}/agreement/verification", ch.providerContractVerificationHandler) - mux.HandleFunc("POST /negotiations/{providerPID}/termination", ch.providerContractTerminationHandler) + mux.Handle("GET /negotiations/{providerPID}", WrapHandlerWithError(ch.providerContractStateHandler)) + mux.Handle("POST /negotiations/request", WrapHandlerWithError(ch.providerContractRequestHandler)) + mux.Handle("POST /negotiations/{providerPID}/request", WrapHandlerWithError(ch.providerContractSpecificRequestHandler)) + mux.Handle("POST /negotiations/{providerPID}/events", WrapHandlerWithError(ch.providerContractEventHandler)) + mux.Handle("POST /negotiations/{providerPID}/agreement/verification", + WrapHandlerWithError(ch.providerContractVerificationHandler)) + mux.Handle("POST /negotiations/{providerPID}/termination", WrapHandlerWithError(ch.providerContractTerminationHandler)) - // Contract negotiation consumer callbacks - mux.HandleFunc("POST /negotiations/offers", ch.consumerContractOfferHandler) - mux.HandleFunc("POST /callback/negotiations/{consumerPID}/offers", ch.consumerContractSpecificOfferHandler) - mux.HandleFunc("POST /callback/negotiations/{consumerPID}/agreement", ch.consumerContractAgreementHandler) - mux.HandleFunc("POST /callback/negotiations/{consumerPID}/events", ch.consumerContractEventHandler) - mux.HandleFunc("POST /callback/negotiations/{consumerPID}/termination", ch.consumerContractTerminationHandler) + // Contract negotiation consumer callbacks) + mux.Handle("POST /negotiations/offers", WrapHandlerWithError(ch.consumerContractOfferHandler)) + mux.Handle("POST /callback/negotiations/{consumerPID}/offers", + WrapHandlerWithError(ch.consumerContractSpecificOfferHandler)) + mux.Handle("POST /callback/negotiations/{consumerPID}/agreement", + WrapHandlerWithError(ch.consumerContractAgreementHandler)) + mux.Handle("POST /callback/negotiations/{consumerPID}/events", WrapHandlerWithError(ch.consumerContractEventHandler)) + mux.Handle("POST /callback/negotiations/{consumerPID}/termination", + WrapHandlerWithError(ch.consumerContractTerminationHandler)) // Transfer process endpoints mux.HandleFunc("GET /transfers/{providerPID}", ch.providerTransferProcessHandler) @@ -74,8 +78,8 @@ func GetDSPRoutes( mux.HandleFunc("POST /callback/transfers/{consumerPID}/termination", ch.consumerTransferTerminationHandler) mux.HandleFunc("POST /callback/transfers/{consumerPID}/suspension", ch.consumerTransferSuspensionHandler) - mux.HandleFunc("GET /triggerconsumer/{datasetID}", ch.triggerConsumerContractRequestHandler) - mux.HandleFunc("GET /triggertransfer/{contractProviderPID}", ch.triggerTransferRequestHandler) + mux.Handle("GET /triggerconsumer/{datasetID}", WrapHandlerWithError(ch.triggerConsumerContractRequestHandler)) + mux.Handle("GET /triggertransfer/{contractProviderPID}", WrapHandlerWithError(ch.triggerTransferRequestHandler)) mux.HandleFunc("GET /completetransfer/{providerPID}", ch.completeTransferRequestHandler) return mux