Skip to content

Commit

Permalink
Handle contract termination. (#54)
Browse files Browse the repository at this point in the history
* Handle contract termination.

This PR adds support for receiving contract termination messages.
It does it by adding a case to handle the termination message to
all the relevant Recv methods and hooking up the Handlers to the
progressContractState function.

* Replace missed code with function
  • Loading branch information
ainmosni authored Aug 30, 2024
1 parent dfaebf2 commit ba7d9f8
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 85 deletions.
53 changes: 8 additions & 45 deletions dsp/contract_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package dsp

import (
"fmt"
"io"
"net/http"
"net/url"
"path"
Expand Down Expand Up @@ -236,29 +235,12 @@ func (dh *dspHandlers) providerContractVerificationHandler(w http.ResponseWriter
)
}

// TODO: Implement terminations.
func (dh *dspHandlers) providerContractTerminationHandler(w http.ResponseWriter, req *http.Request) error {
ctx, logger := logging.InjectLabels(req.Context(), "handler", "providerContractTerminationHandler")
ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractVerificationHandler")
req = req.WithContext(ctx)
providerPID := req.PathValue("providerPID")
if providerPID == "" {
return contractError("missing provider PID", http.StatusBadRequest, "400", "Missing provider PID", nil)
}
reqBody, err := io.ReadAll(req.Body)
if err != nil {
return contractError("could not read body", http.StatusInternalServerError, "500", "could not read body", nil)
}
verification, err := shared.UnmarshalAndValidate(
req.Context(), reqBody, shared.ContractNegotiationTerminationMessage{})
if err != nil {
return contractError("invalid request", http.StatusBadGateway, "400", "invalid request", nil)
}

logger.Debug("Got contract termination", "termination", verification)

// If all goes well, we just return a 200
w.WriteHeader(http.StatusOK)
return nil
return progressContractState[shared.ContractNegotiationTerminationMessage](
dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"),
)
}

func (dh *dspHandlers) consumerContractOfferHandler(w http.ResponseWriter, req *http.Request) error {
Expand Down Expand Up @@ -347,31 +329,12 @@ func (dh *dspHandlers) consumerContractEventHandler(w http.ResponseWriter, req *
)
}

// TODO: Handle termination in the statemachine.
func (dh *dspHandlers) consumerContractTerminationHandler(w http.ResponseWriter, req *http.Request) error {
ctx, logger := logging.InjectLabels(req.Context(), "handler", "consumerContractTerminationHandler")
ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractEventHandler")
req = req.WithContext(ctx)
consumerPID := req.PathValue("consumerPID")
if consumerPID == "" {
return contractError("missing consumer PID", http.StatusBadRequest, "400", "Missing consumer PID", nil)
}
reqBody, err := io.ReadAll(req.Body)
if err != nil {
return contractError("could not read body", http.StatusInternalServerError, "500", "could not read body", nil)
}

// This should have the event FINALIZED
termination, err := shared.UnmarshalAndValidate(req.Context(), reqBody, shared.ContractNegotiationTerminationMessage{})
if err != nil {
return contractError("invalid request", http.StatusBadGateway, "400", "invalid request", nil)
}

logger.Debug("Got contract event", "termination", termination)

// If all goes well, we just return a 200
w.WriteHeader(http.StatusOK)

return nil
return progressContractState[shared.ContractNegotiationTerminationMessage](
dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"),
)
}

func (dh *dspHandlers) triggerConsumerContractRequestHandler(w http.ResponseWriter, req *http.Request) error {
Expand Down
12 changes: 6 additions & 6 deletions dsp/shared/contract_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ type ContractNegotiationEventMessage struct {

// ContractNegotiationTerminationMessage terminates the negotiation.
type ContractNegotiationTerminationMessage struct {
Context jsonld.Context `json:"@context"`
Type string `json:"@type" validate:"required,eq=dspace:ContractNegotiationTerminationMessage"`
ProviderPID string `json:"dspace:providerPid" validate:"required"`
ConsumerPID string `json:"dspace:consumerPid" validate:"required"`
Code string `json:"dspace:code"`
Reason []map[string]any `json:"dspace:reason"`
Context jsonld.Context `json:"@context"`
Type string `json:"@type" validate:"required,eq=dspace:ContractNegotiationTerminationMessage"`
ProviderPID string `json:"dspace:providerPid" validate:"required"`
ConsumerPID string `json:"dspace:consumerPid" validate:"required"`
Code string `json:"dspace:code"`
Reason []Multilanguage `json:"dspace:reason"`
}

// ContractNegotiation is a response to show the state of the contract negotiation.
Expand Down
76 changes: 76 additions & 0 deletions dsp/statemachine/contract_statemachine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,82 @@ func TestContractStateMachineProviderInit(t *testing.T) {
validateContract(t, err, pOffNext.GetContract(), statemachine.ContractStates.FINALIZED, false)
}

//nolint:funlen
func TestTermination(t *testing.T) {
t.Parallel()

offer := odrl.Offer{
MessageOffer: odrl.MessageOffer{
PolicyClass: odrl.PolicyClass{
AbstractPolicyRule: odrl.AbstractPolicyRule{},
ID: uuid.New().URN(),
},
Type: "odrl:Offer",
Target: target.URN(),
},
}

logger := logging.NewJSON("error", true)
ctx := logging.Inject(context.Background(), logger)
ctx, done := context.WithCancel(ctx)
defer done()

store := statemachine.NewMemoryArchiver()
requester := &MockRequester{}

mockProvider := mockprovider.NewMockProviderServiceClient(t)

reconciler := statemachine.NewReconciler(ctx, requester, store)
reconciler.Run()

for _, role := range []statemachine.DataspaceRole{
statemachine.DataspaceConsumer,
statemachine.DataspaceProvider,
} {
for _, state := range []statemachine.ContractState{
statemachine.ContractStates.REQUESTED,
statemachine.ContractStates.OFFERED,
statemachine.ContractStates.ACCEPTED,
statemachine.ContractStates.AGREED,
statemachine.ContractStates.VERIFIED,
} {
consumerPID := uuid.New()
providerPID := uuid.New()
ctx, consumerInit, err := statemachine.NewContract(
ctx, store, mockProvider, reconciler, providerPID, consumerPID,
state, offer, providerCallback, consumerCallback, role)
assert.Nil(t, err)
msg := shared.ContractNegotiationTerminationMessage{
Context: shared.GetDSPContext(),
Type: "dspace:ContractNegotiationTerminationMessage",
ProviderPID: providerPID.URN(),
ConsumerPID: consumerPID.URN(),
Code: "meh",
Reason: []shared.Multilanguage{
{
Language: "en",
Value: "test",
},
},
}
ctx, next, err := consumerInit.Recv(ctx, msg)
assert.IsType(t, &statemachine.ContractNegotiationTerminated{}, next)
assert.Nil(t, err)
_, err = next.Send(ctx)
assert.Nil(t, err)
var contract *statemachine.Contract
switch role {
case statemachine.DataspaceProvider:
contract, err = store.GetProviderContract(ctx, providerPID)
case statemachine.DataspaceConsumer:
contract, err = store.GetConsumerContract(ctx, consumerPID)
}
assert.Nil(t, err)
assert.Equal(t, statemachine.ContractStates.TERMINATED, contract.GetState())
}
}
}

func validateContract(
t *testing.T, err error, c *statemachine.Contract, state statemachine.ContractState, provInit bool,
) {
Expand Down
92 changes: 58 additions & 34 deletions dsp/statemachine/contract_transitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ func (cn *ContractNegotiationRequested) Recv(
ctx, logger = logging.InjectLabels(ctx,
"recv_msg_type", fmt.Sprintf("%T", t),
)
case shared.ContractNegotiationTerminationMessage:
return processTermination(ctx, t, cn)
default:
return ctx, nil, fmt.Errorf("unsupported message type")
}
Expand Down Expand Up @@ -286,6 +288,8 @@ func (cn *ContractNegotiationOffered) Recv(
}
targetState = receivedStatus
logger.Debug("Received message")
case shared.ContractNegotiationTerminationMessage:
return processTermination(ctx, t, cn)
default:
return ctx, nil, fmt.Errorf("unsupported message type")
}
Expand Down Expand Up @@ -315,16 +319,19 @@ func (cn *ContractNegotiationAccepted) Recv(
) (context.Context, ContractNegotiationState, error) {
ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn))
logger.Debug("Receiving message")
m, ok := message.(shared.ContractAgreementMessage)
if !ok {
switch t := message.(type) {
case shared.ContractAgreementMessage:
ctx, logger = logging.InjectLabels(ctx,
"recv_msg_type", fmt.Sprintf("%T", t),
)
logger.Debug("Received message")
cn.agreement = t.Agreement
return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, t.CallbackAddress, ContractStates.AGREED)
case shared.ContractNegotiationTerminationMessage:
return processTermination(ctx, t, cn)
default:
return ctx, nil, fmt.Errorf("unsupported message type")
}
ctx, logger = logging.InjectLabels(ctx,
"recv_msg_type", fmt.Sprintf("%T", m),
)
logger.Debug("Received message")
cn.agreement = m.Agreement
return verifyAndTransform(ctx, cn, m.ProviderPID, m.ConsumerPID, m.CallbackAddress, ContractStates.AGREED)
}

func (cn *ContractNegotiationAccepted) Send(ctx context.Context) (func(), error) {
Expand All @@ -341,16 +348,18 @@ func (cn *ContractNegotiationAgreed) Recv(
ctx context.Context, message any,
) (context.Context, ContractNegotiationState, error) {
ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn))
logger.Info("Receiving me")
m, ok := message.(shared.ContractAgreementVerificationMessage)
if !ok {
logger.Info("Receiving message")
switch t := message.(type) {
case shared.ContractAgreementVerificationMessage:
ctx, _ = logging.InjectLabels(ctx,
"recv_msg_type", fmt.Sprintf("%T", t),
)
return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.VERIFIED)
case shared.ContractNegotiationTerminationMessage:
return processTermination(ctx, t, cn)
default:
return ctx, nil, fmt.Errorf("unsupported message type")
}
ctx, logger = logging.InjectLabels(ctx,
"recv_msg_type", fmt.Sprintf("%T", m),
)
logger.Debug("Received message")
return verifyAndTransform(ctx, cn, m.ProviderPID, m.ConsumerPID, cn.GetCallback().String(), ContractStates.VERIFIED)
}

func (cn *ContractNegotiationAgreed) Send(ctx context.Context) (func(), error) {
Expand All @@ -368,26 +377,29 @@ func (cn *ContractNegotiationVerified) Recv(
) (context.Context, ContractNegotiationState, error) {
ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn))
logger.Debug("Receiving message")
m, ok := message.(shared.ContractNegotiationEventMessage)
if !ok {
switch t := message.(type) {
case shared.ContractNegotiationEventMessage:
ctx, logger = logging.InjectLabels(ctx,
"recv_msg_type", fmt.Sprintf("%T", t),
"event_type", t.EventType,
)
receivedStatus, err := ParseContractState(t.EventType)
if err != nil {
logger.Error("event does not contain the proper status", "err", err)
return ctx, nil, fmt.Errorf("event %s does not contain proper status: %w", t.EventType, err)
}
if receivedStatus != ContractStates.FINALIZED {
logger.Error("invalid status")
return ctx, nil, fmt.Errorf("invalid status: %s", receivedStatus)
}
logger.Debug("Received message")
return verifyAndTransform(
ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.FINALIZED)
case shared.ContractNegotiationTerminationMessage:
return processTermination(ctx, t, cn)
default:
return ctx, nil, fmt.Errorf("unsupported message type")
}
ctx, logger = logging.InjectLabels(ctx,
"recv_msg_type", fmt.Sprintf("%T", m),
"event_type", m.EventType,
)
receivedStatus, err := ParseContractState(m.EventType)
if err != nil {
logger.Error("event does not contain the proper status", "err", err)
return ctx, nil, fmt.Errorf("event %s does not contain proper status: %w", m.EventType, err)
}
if receivedStatus != ContractStates.FINALIZED {
logger.Error("invalid status")
return ctx, nil, fmt.Errorf("invalid status: %s", receivedStatus)
}
logger.Debug("Received message")
return verifyAndTransform(
ctx, cn, m.ProviderPID, m.ConsumerPID, cn.GetCallback().String(), ContractStates.FINALIZED)
}

func (cn *ContractNegotiationVerified) Send(ctx context.Context) (func(), error) {
Expand Down Expand Up @@ -561,3 +573,15 @@ func verifyAndTransform(
ctx, cns := GetContractNegotiation(ctx, cn.GetArchiver(), cn.GetContract(), cn.GetProvider(), cn.GetReconciler())
return ctx, cns, nil
}

func processTermination(
ctx context.Context, t shared.ContractNegotiationTerminationMessage, cn ContractNegotiationState,
) (context.Context, ContractNegotiationState, error) {
logger := logging.Extract(ctx)
logger = logger.With("termination_code", t.Code)
for _, reason := range t.Reason {
logger = logger.With(fmt.Sprintf("reason_%s", reason.Language), reason.Value)
}
ctx = logging.Inject(ctx, logger)
return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.TERMINATED)
}

0 comments on commit ba7d9f8

Please sign in to comment.