diff --git a/dsp/contract_handlers.go b/dsp/contract_handlers.go index e5ad239..3d6043c 100644 --- a/dsp/contract_handlers.go +++ b/dsp/contract_handlers.go @@ -16,7 +16,6 @@ package dsp import ( "fmt" - "io" "net/http" "net/url" "path" @@ -236,29 +235,12 @@ func (dh *dspHandlers) providerContractVerificationHandler(w http.ResponseWriter ) } -// TODO: Implement terminations. func (dh *dspHandlers) providerContractTerminationHandler(w http.ResponseWriter, req *http.Request) error { - ctx, logger := logging.InjectLabels(req.Context(), "handler", "providerContractTerminationHandler") + ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractVerificationHandler") req = req.WithContext(ctx) - providerPID := req.PathValue("providerPID") - if providerPID == "" { - return contractError("missing provider PID", http.StatusBadRequest, "400", "Missing provider PID", nil) - } - reqBody, err := io.ReadAll(req.Body) - if err != nil { - return contractError("could not read body", http.StatusInternalServerError, "500", "could not read body", nil) - } - verification, err := shared.UnmarshalAndValidate( - req.Context(), reqBody, shared.ContractNegotiationTerminationMessage{}) - if err != nil { - return contractError("invalid request", http.StatusBadGateway, "400", "invalid request", nil) - } - - logger.Debug("Got contract termination", "termination", verification) - - // If all goes well, we just return a 200 - w.WriteHeader(http.StatusOK) - return nil + return progressContractState[shared.ContractNegotiationTerminationMessage]( + dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), + ) } func (dh *dspHandlers) consumerContractOfferHandler(w http.ResponseWriter, req *http.Request) error { @@ -347,31 +329,12 @@ func (dh *dspHandlers) consumerContractEventHandler(w http.ResponseWriter, req * ) } -// TODO: Handle termination in the statemachine. func (dh *dspHandlers) consumerContractTerminationHandler(w http.ResponseWriter, req *http.Request) error { - ctx, logger := logging.InjectLabels(req.Context(), "handler", "consumerContractTerminationHandler") + ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractEventHandler") req = req.WithContext(ctx) - consumerPID := req.PathValue("consumerPID") - if consumerPID == "" { - return contractError("missing consumer PID", http.StatusBadRequest, "400", "Missing consumer PID", nil) - } - reqBody, err := io.ReadAll(req.Body) - if err != nil { - return contractError("could not read body", http.StatusInternalServerError, "500", "could not read body", nil) - } - - // This should have the event FINALIZED - termination, err := shared.UnmarshalAndValidate(req.Context(), reqBody, shared.ContractNegotiationTerminationMessage{}) - if err != nil { - return contractError("invalid request", http.StatusBadGateway, "400", "invalid request", nil) - } - - logger.Debug("Got contract event", "termination", termination) - - // If all goes well, we just return a 200 - w.WriteHeader(http.StatusOK) - - return nil + return progressContractState[shared.ContractNegotiationTerminationMessage]( + dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), + ) } func (dh *dspHandlers) triggerConsumerContractRequestHandler(w http.ResponseWriter, req *http.Request) error { diff --git a/dsp/shared/contract_types.go b/dsp/shared/contract_types.go index d386b2d..461fa74 100644 --- a/dsp/shared/contract_types.go +++ b/dsp/shared/contract_types.go @@ -68,12 +68,12 @@ type ContractNegotiationEventMessage struct { // ContractNegotiationTerminationMessage terminates the negotiation. type ContractNegotiationTerminationMessage struct { - Context jsonld.Context `json:"@context"` - Type string `json:"@type" validate:"required,eq=dspace:ContractNegotiationTerminationMessage"` - ProviderPID string `json:"dspace:providerPid" validate:"required"` - ConsumerPID string `json:"dspace:consumerPid" validate:"required"` - Code string `json:"dspace:code"` - Reason []map[string]any `json:"dspace:reason"` + Context jsonld.Context `json:"@context"` + Type string `json:"@type" validate:"required,eq=dspace:ContractNegotiationTerminationMessage"` + ProviderPID string `json:"dspace:providerPid" validate:"required"` + ConsumerPID string `json:"dspace:consumerPid" validate:"required"` + Code string `json:"dspace:code"` + Reason []Multilanguage `json:"dspace:reason"` } // ContractNegotiation is a response to show the state of the contract negotiation. diff --git a/dsp/statemachine/contract_statemachine_test.go b/dsp/statemachine/contract_statemachine_test.go index 6a7c3d0..f115cad 100644 --- a/dsp/statemachine/contract_statemachine_test.go +++ b/dsp/statemachine/contract_statemachine_test.go @@ -569,6 +569,82 @@ func TestContractStateMachineProviderInit(t *testing.T) { validateContract(t, err, pOffNext.GetContract(), statemachine.ContractStates.FINALIZED, false) } +//nolint:funlen +func TestTermination(t *testing.T) { + t.Parallel() + + offer := odrl.Offer{ + MessageOffer: odrl.MessageOffer{ + PolicyClass: odrl.PolicyClass{ + AbstractPolicyRule: odrl.AbstractPolicyRule{}, + ID: uuid.New().URN(), + }, + Type: "odrl:Offer", + Target: target.URN(), + }, + } + + logger := logging.NewJSON("error", true) + ctx := logging.Inject(context.Background(), logger) + ctx, done := context.WithCancel(ctx) + defer done() + + store := statemachine.NewMemoryArchiver() + requester := &MockRequester{} + + mockProvider := mockprovider.NewMockProviderServiceClient(t) + + reconciler := statemachine.NewReconciler(ctx, requester, store) + reconciler.Run() + + for _, role := range []statemachine.DataspaceRole{ + statemachine.DataspaceConsumer, + statemachine.DataspaceProvider, + } { + for _, state := range []statemachine.ContractState{ + statemachine.ContractStates.REQUESTED, + statemachine.ContractStates.OFFERED, + statemachine.ContractStates.ACCEPTED, + statemachine.ContractStates.AGREED, + statemachine.ContractStates.VERIFIED, + } { + consumerPID := uuid.New() + providerPID := uuid.New() + ctx, consumerInit, err := statemachine.NewContract( + ctx, store, mockProvider, reconciler, providerPID, consumerPID, + state, offer, providerCallback, consumerCallback, role) + assert.Nil(t, err) + msg := shared.ContractNegotiationTerminationMessage{ + Context: shared.GetDSPContext(), + Type: "dspace:ContractNegotiationTerminationMessage", + ProviderPID: providerPID.URN(), + ConsumerPID: consumerPID.URN(), + Code: "meh", + Reason: []shared.Multilanguage{ + { + Language: "en", + Value: "test", + }, + }, + } + ctx, next, err := consumerInit.Recv(ctx, msg) + assert.IsType(t, &statemachine.ContractNegotiationTerminated{}, next) + assert.Nil(t, err) + _, err = next.Send(ctx) + assert.Nil(t, err) + var contract *statemachine.Contract + switch role { + case statemachine.DataspaceProvider: + contract, err = store.GetProviderContract(ctx, providerPID) + case statemachine.DataspaceConsumer: + contract, err = store.GetConsumerContract(ctx, consumerPID) + } + assert.Nil(t, err) + assert.Equal(t, statemachine.ContractStates.TERMINATED, contract.GetState()) + } + } +} + func validateContract( t *testing.T, err error, c *statemachine.Contract, state statemachine.ContractState, provInit bool, ) { diff --git a/dsp/statemachine/contract_transitions.go b/dsp/statemachine/contract_transitions.go index 1be4be8..04a3b70 100644 --- a/dsp/statemachine/contract_transitions.go +++ b/dsp/statemachine/contract_transitions.go @@ -220,6 +220,8 @@ func (cn *ContractNegotiationRequested) Recv( ctx, logger = logging.InjectLabels(ctx, "recv_msg_type", fmt.Sprintf("%T", t), ) + case shared.ContractNegotiationTerminationMessage: + return processTermination(ctx, t, cn) default: return ctx, nil, fmt.Errorf("unsupported message type") } @@ -286,6 +288,8 @@ func (cn *ContractNegotiationOffered) Recv( } targetState = receivedStatus logger.Debug("Received message") + case shared.ContractNegotiationTerminationMessage: + return processTermination(ctx, t, cn) default: return ctx, nil, fmt.Errorf("unsupported message type") } @@ -315,16 +319,19 @@ func (cn *ContractNegotiationAccepted) Recv( ) (context.Context, ContractNegotiationState, error) { ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn)) logger.Debug("Receiving message") - m, ok := message.(shared.ContractAgreementMessage) - if !ok { + switch t := message.(type) { + case shared.ContractAgreementMessage: + ctx, logger = logging.InjectLabels(ctx, + "recv_msg_type", fmt.Sprintf("%T", t), + ) + logger.Debug("Received message") + cn.agreement = t.Agreement + return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, t.CallbackAddress, ContractStates.AGREED) + case shared.ContractNegotiationTerminationMessage: + return processTermination(ctx, t, cn) + default: return ctx, nil, fmt.Errorf("unsupported message type") } - ctx, logger = logging.InjectLabels(ctx, - "recv_msg_type", fmt.Sprintf("%T", m), - ) - logger.Debug("Received message") - cn.agreement = m.Agreement - return verifyAndTransform(ctx, cn, m.ProviderPID, m.ConsumerPID, m.CallbackAddress, ContractStates.AGREED) } func (cn *ContractNegotiationAccepted) Send(ctx context.Context) (func(), error) { @@ -341,16 +348,18 @@ func (cn *ContractNegotiationAgreed) Recv( ctx context.Context, message any, ) (context.Context, ContractNegotiationState, error) { ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn)) - logger.Info("Receiving me") - m, ok := message.(shared.ContractAgreementVerificationMessage) - if !ok { + logger.Info("Receiving message") + switch t := message.(type) { + case shared.ContractAgreementVerificationMessage: + ctx, _ = logging.InjectLabels(ctx, + "recv_msg_type", fmt.Sprintf("%T", t), + ) + return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.VERIFIED) + case shared.ContractNegotiationTerminationMessage: + return processTermination(ctx, t, cn) + default: return ctx, nil, fmt.Errorf("unsupported message type") } - ctx, logger = logging.InjectLabels(ctx, - "recv_msg_type", fmt.Sprintf("%T", m), - ) - logger.Debug("Received message") - return verifyAndTransform(ctx, cn, m.ProviderPID, m.ConsumerPID, cn.GetCallback().String(), ContractStates.VERIFIED) } func (cn *ContractNegotiationAgreed) Send(ctx context.Context) (func(), error) { @@ -368,26 +377,29 @@ func (cn *ContractNegotiationVerified) Recv( ) (context.Context, ContractNegotiationState, error) { ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn)) logger.Debug("Receiving message") - m, ok := message.(shared.ContractNegotiationEventMessage) - if !ok { + switch t := message.(type) { + case shared.ContractNegotiationEventMessage: + ctx, logger = logging.InjectLabels(ctx, + "recv_msg_type", fmt.Sprintf("%T", t), + "event_type", t.EventType, + ) + receivedStatus, err := ParseContractState(t.EventType) + if err != nil { + logger.Error("event does not contain the proper status", "err", err) + return ctx, nil, fmt.Errorf("event %s does not contain proper status: %w", t.EventType, err) + } + if receivedStatus != ContractStates.FINALIZED { + logger.Error("invalid status") + return ctx, nil, fmt.Errorf("invalid status: %s", receivedStatus) + } + logger.Debug("Received message") + return verifyAndTransform( + ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.FINALIZED) + case shared.ContractNegotiationTerminationMessage: + return processTermination(ctx, t, cn) + default: return ctx, nil, fmt.Errorf("unsupported message type") } - ctx, logger = logging.InjectLabels(ctx, - "recv_msg_type", fmt.Sprintf("%T", m), - "event_type", m.EventType, - ) - receivedStatus, err := ParseContractState(m.EventType) - if err != nil { - logger.Error("event does not contain the proper status", "err", err) - return ctx, nil, fmt.Errorf("event %s does not contain proper status: %w", m.EventType, err) - } - if receivedStatus != ContractStates.FINALIZED { - logger.Error("invalid status") - return ctx, nil, fmt.Errorf("invalid status: %s", receivedStatus) - } - logger.Debug("Received message") - return verifyAndTransform( - ctx, cn, m.ProviderPID, m.ConsumerPID, cn.GetCallback().String(), ContractStates.FINALIZED) } func (cn *ContractNegotiationVerified) Send(ctx context.Context) (func(), error) { @@ -561,3 +573,15 @@ func verifyAndTransform( ctx, cns := GetContractNegotiation(ctx, cn.GetArchiver(), cn.GetContract(), cn.GetProvider(), cn.GetReconciler()) return ctx, cns, nil } + +func processTermination( + ctx context.Context, t shared.ContractNegotiationTerminationMessage, cn ContractNegotiationState, +) (context.Context, ContractNegotiationState, error) { + logger := logging.Extract(ctx) + logger = logger.With("termination_code", t.Code) + for _, reason := range t.Reason { + logger = logger.With(fmt.Sprintf("reason_%s", reason.Language), reason.Value) + } + ctx = logging.Inject(ctx, logger) + return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.TERMINATED) +}