diff --git a/dsp/common_handlers.go b/dsp/common_handlers.go index 79913ec..751cbea 100644 --- a/dsp/common_handlers.go +++ b/dsp/common_handlers.go @@ -20,6 +20,7 @@ import ( "net/http" "net/url" + "github.com/go-dataspace/run-dsp/dsp/persistence" "github.com/go-dataspace/run-dsp/dsp/shared" "github.com/go-dataspace/run-dsp/dsp/statemachine" "github.com/go-dataspace/run-dsp/internal/constants" @@ -29,7 +30,7 @@ import ( ) type dspHandlers struct { - store statemachine.Archiver + store persistence.StorageProvider provider providerv1.ProviderServiceClient reconciler *statemachine.Reconciler selfURL *url.URL diff --git a/dsp/constants/roles.go b/dsp/constants/roles.go new file mode 100644 index 0000000..feff0fb --- /dev/null +++ b/dsp/constants/roles.go @@ -0,0 +1,23 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package constants + +// DataspaceRole signifies what role in a dataspace exchange this is. +type DataspaceRole uint8 + +const ( + DataspaceConsumer DataspaceRole = iota + DataspaceProvider +) diff --git a/dsp/contract/doc.go b/dsp/contract/doc.go new file mode 100644 index 0000000..b7682b0 --- /dev/null +++ b/dsp/contract/doc.go @@ -0,0 +1,16 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package contract contains the Contract type and all related code. +package contract diff --git a/dsp/contract/negotiation.go b/dsp/contract/negotiation.go new file mode 100644 index 0000000..6d34bd1 --- /dev/null +++ b/dsp/contract/negotiation.go @@ -0,0 +1,241 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package contract + +import ( + "bytes" + "encoding/gob" + "fmt" + "net/url" + "slices" + "strconv" + + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/shared" + "github.com/go-dataspace/run-dsp/odrl" + "github.com/google/uuid" +) + +var validTransitions = map[State][]State{ + States.INITIAL: { + States.OFFERED, + States.REQUESTED, + States.TERMINATED, + }, + States.REQUESTED: { + States.OFFERED, + States.AGREED, + States.TERMINATED, + }, + States.OFFERED: { + States.REQUESTED, + States.ACCEPTED, + States.TERMINATED, + }, + States.ACCEPTED: { + States.AGREED, + States.TERMINATED, + }, + States.AGREED: { + States.VERIFIED, + States.TERMINATED, + }, + States.VERIFIED: { + States.FINALIZED, + States.TERMINATED, + }, + States.FINALIZED: {}, + States.TERMINATED: {}, +} + +// Negotiation represents a contract negotiation. +type Negotiation struct { + providerPID uuid.UUID + consumerPID uuid.UUID + state State + offer odrl.Offer + agreement *odrl.Agreement + callback *url.URL + self *url.URL + role constants.DataspaceRole + + initial bool + ro bool + modified bool +} + +type storableNegotiation struct { + ProviderPID uuid.UUID + ConsumerPID uuid.UUID + State State + Offer odrl.Offer + Agreement *odrl.Agreement + Callback *url.URL + Self *url.URL + Role constants.DataspaceRole +} + +func New( + providerPID, consumerPID uuid.UUID, + state State, + offer odrl.Offer, + callback, self *url.URL, + role constants.DataspaceRole, +) *Negotiation { + return &Negotiation{ + providerPID: providerPID, + consumerPID: consumerPID, + state: state, + offer: offer, + callback: callback, + self: self, + role: role, + modified: true, + } +} + +func FromBytes(b []byte) (*Negotiation, error) { + var sn storableNegotiation + r := bytes.NewReader(b) + dec := gob.NewDecoder(r) + if err := dec.Decode(&sn); err != nil { + return nil, fmt.Errorf("Could not decode bytes into storableNegotiation: %w", err) + } + return &Negotiation{ + providerPID: sn.ProviderPID, + consumerPID: sn.ConsumerPID, + state: sn.State, + offer: sn.Offer, + agreement: sn.Agreement, + callback: sn.Callback, + self: sn.Self, + role: sn.Role, + }, nil +} + +// GenerateKey generates a key for a contract negotiation. +func GenerateKey(id uuid.UUID, role constants.DataspaceRole) []byte { + return []byte("negotiation-" + id.String() + "-" + strconv.Itoa(int(role))) +} + +// Negotiation getters. +func (cn *Negotiation) GetProviderPID() uuid.UUID { return cn.providerPID } +func (cn *Negotiation) GetConsumerPID() uuid.UUID { return cn.consumerPID } +func (cn *Negotiation) GetState() State { return cn.state } +func (cn *Negotiation) GetOffer() odrl.Offer { return cn.offer } +func (cn *Negotiation) GetAgreement() *odrl.Agreement { return cn.agreement } +func (cn *Negotiation) GetRole() constants.DataspaceRole { return cn.role } +func (cn *Negotiation) GetCallback() *url.URL { return cn.callback } +func (cn *Negotiation) GetSelf() *url.URL { return cn.self } +func (cn *Negotiation) GetContract() *Negotiation { return cn } + +// Negotiation setters, these will panic when the negotiation is RO. +func (cn *Negotiation) SetProviderPID(u uuid.UUID) { + cn.panicRO() + cn.providerPID = u + cn.modify() +} + +func (cn *Negotiation) SetConsumerPID(u uuid.UUID) { + cn.panicRO() + cn.providerPID = u + cn.modify() +} + +func (cn *Negotiation) SetAgreement(a *odrl.Agreement) { + cn.panicRO() + cn.agreement = a + cn.modify() +} + +func (cn *Negotiation) SetState(state State) error { + cn.panicRO() + if !slices.Contains(validTransitions[cn.state], state) { + return fmt.Errorf("can't transition from %s to %s", cn.state, state) + } + cn.state = state + cn.modify() + return nil +} + +// SetCallback sets the remote callback root. +func (cn *Negotiation) SetCallback(u string) error { + nu, err := url.Parse(u) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + cn.callback = nu + cn.modify() + return nil +} + +// Properties that decisions are based on. +func (cn *Negotiation) ReadOnly() bool { return cn.ro } +func (cn *Negotiation) Initial() bool { return cn.initial } +func (cn *Negotiation) Modified() bool { return cn.modified } +func (cn *Negotiation) StorageKey() []byte { + id := cn.consumerPID + if cn.role == constants.DataspaceProvider { + id = cn.providerPID + } + return GenerateKey(id, cn.role) +} + +// Property setters. +func (cn *Negotiation) SetReadOnly() { cn.ro = true } +func (cn *Negotiation) SetInitial() { cn.initial = true } +func (cn *Negotiation) UnsetInitial() { cn.initial = false } + +// ToBytes returns a binary representation of the negotiation, one that is compatible with the FromBytes +// function. +func (cn *Negotiation) ToBytes() ([]byte, error) { + s := storableNegotiation{ + ProviderPID: cn.providerPID, + ConsumerPID: cn.consumerPID, + State: cn.state, + Offer: cn.offer, + Agreement: cn.agreement, + Callback: cn.callback, + Self: cn.self, + Role: cn.role, + } + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(s); err != nil { + return nil, fmt.Errorf("could not encode negotiation: %w", err) + } + return buf.Bytes(), nil +} + +// GetContractNegotiation returns a ContractNegotion message. +func (cn *Negotiation) GetContractNegotiation() shared.ContractNegotiation { + return shared.ContractNegotiation{ + Context: shared.GetDSPContext(), + Type: "dspace:ContractNegotiation", + ConsumerPID: cn.GetConsumerPID().URN(), + ProviderPID: cn.GetProviderPID().URN(), + State: cn.GetState().String(), + } +} + +func (cn *Negotiation) panicRO() { + if cn.ro { + panic("Trying to write to a read-only negotiation, this is certainly a bug.") + } +} + +func (cn *Negotiation) modify() { + cn.modified = true +} diff --git a/dsp/statemachine/contract_state.go b/dsp/contract/state.go similarity index 58% rename from dsp/statemachine/contract_state.go rename to dsp/contract/state.go index 51d684f..73bcaba 100644 --- a/dsp/statemachine/contract_state.go +++ b/dsp/contract/state.go @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package statemachine +package contract -type contractState int +type state int //go:generate goenums contract_state.go const ( - initial contractState = iota // INITIAL - requested // dspace:REQUESTED - offered // dspace:OFFERED - agreed // dspace:AGREED - accepted // dspace:ACCEPTED - verified // dspace:VERIFIED - finalized // dspace:FINALIZED - terminated // dspace:TERMINATED + initial state = iota // INITIAL + requested // dspace:REQUESTED + offered // dspace:OFFERED + agreed // dspace:AGREED + accepted // dspace:ACCEPTED + verified // dspace:VERIFIED + finalized // dspace:FINALIZED + terminated // dspace:TERMINATED ) diff --git a/dsp/contract/states_enums.go b/dsp/contract/states_enums.go new file mode 100644 index 0000000..05033a6 --- /dev/null +++ b/dsp/contract/states_enums.go @@ -0,0 +1,194 @@ +// Code generated by goenums. DO NOT EDIT. +// This file was generated by github.com/zarldev/goenums +// using the command: +// goenums ./dsp/contract/state.go + +package contract + +import ( + "bytes" + "database/sql/driver" + "fmt" + "strconv" +) + +type State struct { + state +} + +type statesContainer struct { + INITIAL State + REQUESTED State + OFFERED State + AGREED State + ACCEPTED State + VERIFIED State + FINALIZED State + TERMINATED State +} + +var States = statesContainer{ + INITIAL: State{ + state: initial, + }, + REQUESTED: State{ + state: requested, + }, + OFFERED: State{ + state: offered, + }, + AGREED: State{ + state: agreed, + }, + ACCEPTED: State{ + state: accepted, + }, + VERIFIED: State{ + state: verified, + }, + FINALIZED: State{ + state: finalized, + }, + TERMINATED: State{ + state: terminated, + }, +} + +func (c statesContainer) All() []State { + return []State{ + c.INITIAL, + c.REQUESTED, + c.OFFERED, + c.AGREED, + c.ACCEPTED, + c.VERIFIED, + c.FINALIZED, + c.TERMINATED, + } +} + +var invalidState = State{} + +func ParseState(a any) (State, error) { + res := invalidState + switch v := a.(type) { + case State: + return v, nil + case []byte: + res = stringToState(string(v)) + case string: + res = stringToState(v) + case fmt.Stringer: + res = stringToState(v.String()) + case int: + res = intToState(v) + case int64: + res = intToState(int(v)) + case int32: + res = intToState(int(v)) + } + return res, nil +} + +func stringToState(s string) State { + switch s { + case "INITIAL": + return States.INITIAL + case "dspace:REQUESTED": + return States.REQUESTED + case "dspace:OFFERED": + return States.OFFERED + case "dspace:AGREED": + return States.AGREED + case "dspace:ACCEPTED": + return States.ACCEPTED + case "dspace:VERIFIED": + return States.VERIFIED + case "dspace:FINALIZED": + return States.FINALIZED + case "dspace:TERMINATED": + return States.TERMINATED + } + return invalidState +} + +func intToState(i int) State { + if i < 0 || i >= len(States.All()) { + return invalidState + } + return States.All()[i] +} + +func ExhaustiveStates(f func(State)) { + for _, p := range States.All() { + f(p) + } +} + +var validStates = map[State]bool{ + States.INITIAL: true, + States.REQUESTED: true, + States.OFFERED: true, + States.AGREED: true, + States.ACCEPTED: true, + States.VERIFIED: true, + States.FINALIZED: true, + States.TERMINATED: true, +} + +func (p State) IsValid() bool { + return validStates[p] +} + +func (p State) MarshalJSON() ([]byte, error) { + return []byte(`"` + p.String() + `"`), nil +} + +func (p *State) UnmarshalJSON(b []byte) error { + b = bytes.Trim(bytes.Trim(b, `"`), ` `) + newp, err := ParseState(b) + if err != nil { + return err + } + *p = newp + return nil +} + +func (p *State) Scan(value any) error { + newp, err := ParseState(value) + if err != nil { + return err + } + *p = newp + return nil +} + +func (p State) Value() (driver.Value, error) { + return p.String(), nil +} + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the goenums command to generate them again. + // Does not identify newly added constant values unless order changes + var x [1]struct{} + _ = x[initial-0] + _ = x[requested-1] + _ = x[offered-2] + _ = x[agreed-3] + _ = x[accepted-4] + _ = x[verified-5] + _ = x[finalized-6] + _ = x[terminated-7] +} + +const _states_name = "INITIALdspace:REQUESTEDdspace:OFFEREDdspace:AGREEDdspace:ACCEPTEDdspace:VERIFIEDdspace:FINALIZEDdspace:TERMINATED" + +var _states_index = [...]uint16{0, 7, 23, 37, 50, 65, 80, 96, 113} + +func (i state) String() string { + if i < 0 || i >= state(len(_states_index)-1) { + return "states(" + (strconv.FormatInt(int64(i), 10) + ")") + } + return _states_name[_states_index[i]:_states_index[i+1]] +} diff --git a/dsp/contract/states_gob.go b/dsp/contract/states_gob.go new file mode 100644 index 0000000..adcdafd --- /dev/null +++ b/dsp/contract/states_gob.go @@ -0,0 +1,15 @@ +package contract + +func (p State) GobEncode() ([]byte, error) { + return []byte(p.String()), nil +} + +func (p *State) GobDecode(b []byte) error { + newp, err := ParseState(b) + if err != nil { + return err + } + + *p = newp + return nil +} diff --git a/dsp/contract_handlers.go b/dsp/contract_handlers.go index 92b7c73..2c97659 100644 --- a/dsp/contract_handlers.go +++ b/dsp/contract_handlers.go @@ -15,11 +15,15 @@ package dsp import ( + "context" "fmt" "net/http" "net/url" "path" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/contract" + "github.com/go-dataspace/run-dsp/dsp/persistence" "github.com/go-dataspace/run-dsp/dsp/shared" "github.com/go-dataspace/run-dsp/dsp/statemachine" "github.com/go-dataspace/run-dsp/logging" @@ -29,7 +33,7 @@ import ( type ContractError struct { status int - contract *statemachine.Contract + contract *contract.Negotiation dspCode string reason string err string @@ -63,7 +67,7 @@ func (ce ContractError) ConsumerPID() string { } func contractError( - err string, statusCode int, dspCode string, reason string, contract *statemachine.Contract, + err string, statusCode int, dspCode string, reason string, contract *contract.Negotiation, ) ContractError { return ContractError{ status: statusCode, @@ -81,7 +85,7 @@ func (dh *dspHandlers) providerContractStateHandler(w http.ResponseWriter, req * return contractError("invalid provider ID", http.StatusBadRequest, "400", "Invalid provider PID", nil) } - contract, err := dh.store.GetProviderContract(req.Context(), providerPID) + contract, err := dh.store.GetContractR(req.Context(), providerPID, constants.DataspaceProvider) if err != nil { return contractError(err.Error(), http.StatusNotFound, "404", "Contract not found", nil) } @@ -115,47 +119,25 @@ func (dh *dspHandlers) providerContractRequestHandler(w http.ResponseWriter, req http.StatusBadRequest, "400", "Invalid request: Non-valid callback URL.", nil) } - // TODO: Maybe make a function in the statemachine that parses the contract request message. - ctx, pState, err := statemachine.NewContract( - req.Context(), - dh.store, dh.provider, dh.reconciler, - uuid.UUID{}, consumerPID, - statemachine.ContractStates.INITIAL, + negotiation := contract.New( + uuid.UUID{}, + consumerPID, + contract.States.INITIAL, odrl.Offer{MessageOffer: contractReq.Offer}, - cbURL, dh.selfURL, - statemachine.DataspaceProvider, + cbURL, + dh.selfURL, + constants.DataspaceProvider, ) - req = req.WithContext(ctx) - if err != nil { - return contractError(fmt.Sprintf("couldn't create contract: %s", err.Error()), - http.StatusInternalServerError, "500", "Failed to create contract", nil) + if err := storeNegotiation(ctx, dh.store, negotiation); err != nil { + return err } - ctx, nextState, err := pState.Recv(req.Context(), contractReq) - if err != nil { - return contractError( - fmt.Sprintf("couldn't receive message: %s", err.Error()), - http.StatusBadRequest, "400", "Invalid request", pState.GetContract(), - ) - } - req = req.WithContext(ctx) - - apply, err := nextState.Send(req.Context()) - if err != nil { - return contractError(fmt.Sprintf("couldn't progress to next state: %s", err.Error()), - http.StatusInternalServerError, "500", "Not able to progress state", nextState.GetContract()) - } - - if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetContractNegotiation()); err != nil { - logger.Error("Couldn't serve response", "err", err) - } - go apply() - return nil + return processMessage(dh, w, req, negotiation.GetRole(), negotiation.GetProviderPID(), contractReq) } func progressContractState[T any]( - dh *dspHandlers, w http.ResponseWriter, req *http.Request, role statemachine.DataspaceRole, rawPID string, + dh *dspHandlers, w http.ResponseWriter, req *http.Request, role constants.DataspaceRole, rawPID string, ) error { logger := logging.Extract(req.Context()) pid, err := uuid.Parse(rawPID) @@ -163,45 +145,54 @@ func progressContractState[T any]( return contractError(fmt.Sprintf("Invalid PID %s: %s", rawPID, err.Error()), http.StatusBadRequest, "400", "Invalid request: PID is not a UUID", nil) } - - var contract *statemachine.Contract - switch role { - case statemachine.DataspaceConsumer: - contract, err = dh.store.GetConsumerContract(req.Context(), pid) - case statemachine.DataspaceProvider: - contract, err = dh.store.GetProviderContract(req.Context(), pid) - default: - panic(fmt.Sprintf("unexpected statemachine.ContractRole: %#v", role)) - } - if err != nil { - return contractError(fmt.Sprintf("%d contract %s not found: %s", role, pid, err), - http.StatusNotFound, "404", "Contract not found", nil) - } - msg, err := shared.DecodeValid[T](req) if err != nil { return contractError(fmt.Sprintf("could not decode message: %s", err), - http.StatusBadRequest, "400", "Invalid request", contract) + http.StatusBadRequest, "400", "Invalid request", nil) } logger.Debug("Got contract message", "req", msg) - ctx, pState := statemachine.GetContractNegotiation(req.Context(), dh.store, contract, dh.provider, dh.reconciler) - logger = logging.Extract(ctx) - req = req.WithContext(ctx) + return processMessage(dh, w, req, role, pid, msg) +} - ctx, nextState, err := pState.Recv(req.Context(), msg) +func processMessage[T any]( + dh *dspHandlers, + w http.ResponseWriter, + req *http.Request, + role constants.DataspaceRole, + pid uuid.UUID, + msg T, +) error { + logger := logging.Extract(req.Context()) + contract, err := dh.store.GetContractRW(req.Context(), pid, role) + if err != nil { + return contractError(fmt.Sprintf("%d contract %s not found: %s", role, pid, err), + http.StatusNotFound, "404", "Contract not found", nil) + } + + ctx, pState := statemachine.GetContractNegotiation( + req.Context(), + contract, + dh.provider, + dh.reconciler, + ) + + ctx, nextState, err := pState.Recv(ctx, msg) if err != nil { return contractError(fmt.Sprintf("invalid request: %s", err), http.StatusBadRequest, "400", "Invalid request", pState.GetContract()) } - req = req.WithContext(ctx) - apply, err := nextState.Send(req.Context()) + apply, err := nextState.Send(ctx) if err != nil { return contractError(fmt.Sprintf("couldn't progress to next state: %s", err.Error()), http.StatusInternalServerError, "500", "Not able to progress state", nextState.GetContract()) } + err = storeNegotiation(ctx, dh.store, nextState.GetContract()) + if err != nil { + return err + } if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetContractNegotiation()); err != nil { logger.Error("Couldn't serve response", "err", err) @@ -211,11 +202,30 @@ func progressContractState[T any]( return nil } +func storeNegotiation( + ctx context.Context, + store persistence.StorageProvider, + negotiation *contract.Negotiation, +) error { + if err := store.PutContract(ctx, negotiation); err != nil { + return contractError(fmt.Sprintf("couldn't store negotiation: %s", err), + http.StatusInternalServerError, "500", "Not able to store negotiation", negotiation) + } + + if negotiation.Modified() && negotiation.GetAgreement() != nil { + if err := store.PutAgreement(ctx, negotiation.GetAgreement()); err != nil { + return contractError(fmt.Sprintf("couldn't store agreement: %s", err), + http.StatusInternalServerError, "500", "Not able to store agreement", negotiation) + } + } + return nil +} + func (dh *dspHandlers) providerContractSpecificRequestHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractSpecificRequestHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractRequestMessage]( - dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), + dh, w, req, constants.DataspaceProvider, req.PathValue("providerPID"), ) } @@ -223,7 +233,7 @@ func (dh *dspHandlers) providerContractEventHandler(w http.ResponseWriter, req * ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractEventHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractNegotiationEventMessage]( - dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), + dh, w, req, constants.DataspaceProvider, req.PathValue("providerPID"), ) } @@ -231,7 +241,7 @@ func (dh *dspHandlers) providerContractVerificationHandler(w http.ResponseWriter ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractVerificationHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractAgreementVerificationMessage]( - dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), + dh, w, req, constants.DataspaceProvider, req.PathValue("providerPID"), ) } @@ -239,7 +249,7 @@ func (dh *dspHandlers) providerContractTerminationHandler(w http.ResponseWriter, ctx, _ := logging.InjectLabels(req.Context(), "handler", "providerContractVerificationHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractNegotiationTerminationMessage]( - dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), + dh, w, req, constants.DataspaceProvider, req.PathValue("providerPID"), ) } @@ -270,46 +280,28 @@ func (dh *dspHandlers) consumerContractOfferHandler(w http.ResponseWriter, req * panic(err.Error()) } selfURL.Path = path.Join(selfURL.Path, "callback") - ctx, cState, err := statemachine.NewContract( - req.Context(), - dh.store, dh.provider, dh.reconciler, - providerPID, uuid.UUID{}, - statemachine.ContractStates.INITIAL, + + negotiation := contract.New( + providerPID, + uuid.UUID{}, + contract.States.INITIAL, odrl.Offer{MessageOffer: contractOffer.Offer}, - cbURL, selfURL, statemachine.DataspaceConsumer, + cbURL, + selfURL, + constants.DataspaceConsumer, ) - logger = logging.Extract(ctx) - req = req.WithContext(ctx) - if err != nil { - return contractError(fmt.Sprintf("couldn't create contract: %s", err.Error()), - http.StatusInternalServerError, "500", "Failed to create contract", nil) - } - - ctx, nextState, err := cState.Recv(req.Context(), contractOffer) - if err != nil { - return contractError(fmt.Sprintf("couldn't receive message: %s", err.Error()), - http.StatusBadRequest, "400", "Invalid request", cState.GetContract()) - } - req = req.WithContext(ctx) - - apply, err := nextState.Send(req.Context()) - if err != nil { - return contractError(fmt.Sprintf("couldn't progress to next state: %s", err.Error()), - http.StatusInternalServerError, "500", "Not able to progress state", nextState.GetContract()) + if err := storeNegotiation(ctx, dh.store, negotiation); err != nil { + return err } - if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetContractNegotiation()); err != nil { - logger.Error("Couldn't serve response", "err", err) - } - go apply() - return nil + return processMessage(dh, w, req, negotiation.GetRole(), negotiation.GetConsumerPID(), contractOffer) } func (dh *dspHandlers) consumerContractSpecificOfferHandler(w http.ResponseWriter, req *http.Request) error { ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractSpecificOfferHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractOfferMessage]( - dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), + dh, w, req, constants.DataspaceConsumer, req.PathValue("consumerPID"), ) } @@ -317,7 +309,7 @@ func (dh *dspHandlers) consumerContractAgreementHandler(w http.ResponseWriter, r ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractAgreementHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractAgreementMessage]( - dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), + dh, w, req, constants.DataspaceConsumer, req.PathValue("consumerPID"), ) } @@ -325,7 +317,7 @@ func (dh *dspHandlers) consumerContractEventHandler(w http.ResponseWriter, req * ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractEventHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractNegotiationEventMessage]( - dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), + dh, w, req, constants.DataspaceConsumer, req.PathValue("consumerPID"), ) } @@ -333,6 +325,6 @@ func (dh *dspHandlers) consumerContractTerminationHandler(w http.ResponseWriter, ctx, _ := logging.InjectLabels(req.Context(), "handler", "consumerContractEventHandler") req = req.WithContext(ctx) return progressContractState[shared.ContractNegotiationTerminationMessage]( - dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), + dh, w, req, constants.DataspaceConsumer, req.PathValue("consumerPID"), ) } diff --git a/dsp/control/control.go b/dsp/control/control.go index 378f488..aed19e2 100644 --- a/dsp/control/control.go +++ b/dsp/control/control.go @@ -21,8 +21,12 @@ import ( "strings" "time" + dspconstants "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/contract" + "github.com/go-dataspace/run-dsp/dsp/persistence" "github.com/go-dataspace/run-dsp/dsp/shared" "github.com/go-dataspace/run-dsp/dsp/statemachine" + "github.com/go-dataspace/run-dsp/dsp/transfer" "github.com/go-dataspace/run-dsp/internal/constants" "github.com/go-dataspace/run-dsp/jsonld" "github.com/go-dataspace/run-dsp/logging" @@ -40,7 +44,7 @@ type Server struct { dspv1alpha1.ClientServiceServer requester shared.Requester - store statemachine.Archiver + store persistence.StorageProvider reconciler *statemachine.Reconciler provider dspv1alpha1.ProviderServiceClient selfURL *url.URL @@ -48,7 +52,7 @@ type Server struct { func New( requester shared.Requester, - store statemachine.Archiver, + store persistence.StorageProvider, reconciler *statemachine.Reconciler, provider dspv1alpha1.ProviderServiceClient, selfURL *url.URL, @@ -173,11 +177,9 @@ func (s *Server) GetProviderDatasetDownloadInformation( consumerPID := uuid.New() selfURL := shared.MustParseURL(s.selfURL.String()) selfURL.Path = path.Join(selfURL.Path, "callback") - ctx, contractInit, err := statemachine.NewContract( - ctx, - s.store, s.provider, s.reconciler, + negotiation := contract.New( uuid.UUID{}, consumerPID, - statemachine.ContractStates.INITIAL, + contract.States.INITIAL, odrl.Offer{ MessageOffer: odrl.MessageOffer{ PolicyClass: odrl.PolicyClass{ @@ -190,33 +192,44 @@ func (s *Server) GetProviderDatasetDownloadInformation( }, providerURL, selfURL, - statemachine.DataspaceConsumer, + dspconstants.DataspaceConsumer, ) - ctx, logger := logging.InjectLabels(ctx, "method", "GetProviderDownloadInformationRequest") + // Store and retrieve contract negotiation so that it's saved and the locking works. + if err := s.store.PutContract(ctx, negotiation); err != nil { + return nil, status.Errorf(codes.Internal, "couldn't store contract negotiation: %s", err) + } + negotiation, err = s.store.GetContractRW(ctx, negotiation.GetConsumerPID(), negotiation.GetRole()) if err != nil { - return nil, status.Errorf(codes.Internal, "Couldn't create contract") + return nil, status.Errorf(codes.Internal, "couldn't retrieve contract negotiation: %s", err) } + + ctx, contractInit := statemachine.GetContractNegotiation(ctx, negotiation, s.provider, s.reconciler) + ctx, logger := logging.InjectLabels(ctx, "method", "GetProviderDownloadInformationRequest") + apply, err := contractInit.Send(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "Couldn't generate inital contract request.") } + if err := s.store.PutContract(ctx, negotiation); err != nil { + return nil, status.Errorf(codes.Internal, "Couldn't store contract negotiation: %s", err) + } logger.Debug("Beginning contract negotiation") apply() - contract, err := s.store.GetConsumerContract(ctx, consumerPID) + negotiation, err = s.store.GetContractR(ctx, consumerPID, dspconstants.DataspaceConsumer) if err != nil { return nil, status.Errorf(codes.Internal, "could not get consumer contract with PID %s: %s", consumerPID, err) } logger.Info("Starting to monitor contract") checks := 0 - for contract.GetState() != statemachine.ContractStates.FINALIZED { + for negotiation.GetState() != contract.States.FINALIZED { // Only log the status every 10 checks. if checks%10 == 0 { - logger.Info("Contract not finalized", "state", contract.GetState().String()) + logger.Info("Contract not finalized", "state", negotiation.GetState().String()) } time.Sleep(1 * time.Second) - contract, err = s.store.GetConsumerContract(ctx, consumerPID) + negotiation, err = s.store.GetContractR(ctx, consumerPID, dspconstants.DataspaceConsumer) if err != nil { return nil, status.Errorf(codes.Internal, "could not get consumer contract with PID %s: %s", consumerPID, err) } @@ -224,40 +237,53 @@ func (s *Server) GetProviderDatasetDownloadInformation( } logger.Info("Contract finalized, continuing") transferConsumerPID := uuid.New() - agreementID := uuid.MustParse(contract.GetAgreement().ID) + agreementID := uuid.MustParse(negotiation.GetAgreement().ID) + agreement, err := s.store.GetAgreement(ctx, agreementID) + if err != nil { + return nil, status.Errorf(codes.Internal, "could not get agreement with ID %s: %s", agreementID, err) + } - transferInit, err := statemachine.NewTransferRequest( - ctx, - s.store, s.provider, s.reconciler, + transferReq := transfer.New( transferConsumerPID, - agreementID, + agreement, "HTTP_PULL", providerURL, selfURL, - statemachine.DataspaceConsumer, - statemachine.TransferRequestStates.TRANSFERINITIAL, + dspconstants.DataspaceConsumer, + transfer.States.INITIAL, nil, ) + // Save and retrieve the transfer request to get the locks working properly. + if err := s.store.PutTransfer(ctx, transferReq); err != nil { + return nil, status.Errorf(codes.Internal, "Couldn't create transfer request: %s", err) + } + transferReq, err = s.store.GetTransferRW(ctx, transferReq.GetConsumerPID(), dspconstants.DataspaceConsumer) if err != nil { - return nil, status.Errorf(codes.Internal, "Couldn't create transfer request") + return nil, status.Errorf(codes.Internal, "could not retrieve transfer request: %s", err) } + transferInit := statemachine.GetTransferRequestNegotiation(transferReq, s.provider, s.reconciler) + apply, err = transferInit.Send(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "Couldn't generate inital initial request.") } + if err := s.store.PutTransfer(ctx, transferReq); err != nil { + return nil, status.Errorf(codes.Internal, "Couldn't create transfer request: %s", err) + } + logger.Debug("Beginning transfer request") apply() - tReq, err := s.store.GetConsumerTransfer(ctx, transferConsumerPID) + tReq, err := s.store.GetTransferR(ctx, transferConsumerPID, dspconstants.DataspaceConsumer) if err != nil { return nil, status.Errorf(codes.Internal, "could not get consumer transfer with PID %s: %s", transferConsumerPID, err) } logger.Info("Starting to monitor transfer request") - for tReq.GetState() != statemachine.TransferRequestStates.STARTED { + for tReq.GetState() != transfer.States.STARTED { logger.Info("Transfer not started", "state", tReq.GetState().String()) time.Sleep(1 * time.Second) - tReq, err = s.store.GetConsumerTransfer(ctx, transferConsumerPID) + tReq, err = s.store.GetTransferR(ctx, transferConsumerPID, dspconstants.DataspaceConsumer) if err != nil { return nil, status.Errorf( codes.Internal, "could not get consumer contract with PID %s: %s", transferConsumerPID, err, @@ -279,19 +305,20 @@ func (s *Server) SignalTransferComplete( if err != nil { return nil, status.Errorf(codes.InvalidArgument, "Transfer ID is not a valid UUID.") } - trReq, err := s.store.GetConsumerTransfer(ctx, id) + trReq, err := s.store.GetTransferRW(ctx, id, dspconstants.DataspaceConsumer) if err != nil { return nil, status.Errorf(codes.NotFound, "no transfer found") } - transferState := statemachine.GetTransferRequestNegotiation(s.store, trReq, s.provider, s.reconciler) + transferState := statemachine.GetTransferRequestNegotiation(trReq, s.provider, s.reconciler) apply, err := transferState.Send(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "couldn't finish transfer: %s", err) } apply() - for trReq.GetState() != statemachine.TransferRequestStates.COMPLETED { + // TODO: potentially save here + for trReq.GetState() != transfer.States.COMPLETED { time.Sleep(1 * time.Second) - trReq, err = s.store.GetConsumerTransfer(ctx, id) + trReq, err = s.store.GetTransferR(ctx, id, dspconstants.DataspaceConsumer) if err != nil { return nil, status.Errorf(codes.Internal, "could not get consumer contract with PID %s: %s", id, err) } diff --git a/dsp/persistance/badger/agreement_saver.go b/dsp/persistence/badger/agreement_saver.go similarity index 71% rename from dsp/persistance/badger/agreement_saver.go rename to dsp/persistence/badger/agreement_saver.go index b9ad2ba..adc6134 100644 --- a/dsp/persistance/badger/agreement_saver.go +++ b/dsp/persistence/badger/agreement_saver.go @@ -15,7 +15,10 @@ package badger import ( + "bytes" "context" + "encoding/gob" + "fmt" "github.com/go-dataspace/run-dsp/odrl" "github.com/google/uuid" @@ -31,7 +34,16 @@ func (sp *StorageProvider) GetAgreement( id uuid.UUID, ) (*odrl.Agreement, error) { key := mkAgreementKey(id.String()) - return get[*odrl.Agreement](sp.db, key) + b, err := get(sp.db, key) + if err != nil { + return nil, err + } + var a odrl.Agreement + dec := gob.NewDecoder(bytes.NewReader(b)) + if err := dec.Decode(&a); err != nil { + return nil, fmt.Errorf("could not encode bytes into Agreement: %w", err) + } + return &a, nil } // PutAgreement stores an agreement, but should return an error if the agreement ID already @@ -42,5 +54,10 @@ func (sp *StorageProvider) PutAgreement(ctx context.Context, agreement *odrl.Agr return err } key := mkAgreementKey(id.String()) - return put(sp.db, key, agreement) + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(agreement); err != nil { + return fmt.Errorf("could not encode ODRL Agreement: %w", err) + } + return put(sp.db, key, buf.Bytes()) } diff --git a/dsp/persistance/badger/contract_saver.go b/dsp/persistence/badger/contract_saver.go similarity index 69% rename from dsp/persistance/badger/contract_saver.go rename to dsp/persistence/badger/contract_saver.go index 36f5e5f..db5a5fa 100644 --- a/dsp/persistance/badger/contract_saver.go +++ b/dsp/persistence/badger/contract_saver.go @@ -19,7 +19,8 @@ import ( "context" "fmt" - "github.com/go-dataspace/run-dsp/dsp/statemachine" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/contract" "github.com/go-dataspace/run-dsp/logging" "github.com/google/uuid" ) @@ -29,17 +30,21 @@ import ( func (sp *StorageProvider) GetContractR( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, -) (*statemachine.Contract, error) { - key := statemachine.MkContractKey(pid, role) + role constants.DataspaceRole, +) (*contract.Negotiation, error) { + key := contract.GenerateKey(pid, role) logger := logging.Extract(ctx).With("pid", pid, "role", role, "key", string(key)) - contract, err := get[*statemachine.Contract](sp.db, key) + b, err := get(sp.db, key) if err != nil { logger.Error("Failed to get contract", "err", err) return nil, fmt.Errorf("could not get contract: %w", err) } - contract.SetReadOnly() - return contract, nil + negotiation, err := contract.FromBytes(b) + if err != nil { + return nil, err + } + negotiation.SetReadOnly() + return negotiation, nil } // GetContractRW gets a contract but does NOT set the read-only property, allowing changes to be saved. @@ -48,22 +53,26 @@ func (sp *StorageProvider) GetContractR( func (sp *StorageProvider) GetContractRW( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, -) (*statemachine.Contract, error) { - key := statemachine.MkContractKey(pid, role) + role constants.DataspaceRole, +) (*contract.Negotiation, error) { + key := contract.GenerateKey(pid, role) ctx, _ = logging.InjectLabels(ctx, "type", "contract", "pid", pid, "role", role, "key", string(key)) - return getLocked[*statemachine.Contract](ctx, sp, key) + b, err := getLocked(ctx, sp, key) + if err != nil { + return nil, err + } + return contract.FromBytes(b) } // PutContract saves a contract to the database. // If the contract is set to read-only, it will panic as this is a bug in the code. // It will release the lock after it has saved. -func (sp *StorageProvider) PutContract(ctx context.Context, contract *statemachine.Contract) error { +func (sp *StorageProvider) PutContract(ctx context.Context, negotiation *contract.Negotiation) error { ctx, _ = logging.InjectLabels( ctx, - "consumer_pid", contract.ConsumerPID, - "provider_pid", contract.ProviderPID, - "role", contract.Role, + "consumer_pid", negotiation.GetConsumerPID(), + "provider_pid", negotiation.GetProviderPID(), + "role", negotiation.GetRole(), ) - return putUnlock(ctx, sp, contract) + return putUnlock(ctx, sp, negotiation) } diff --git a/dsp/persistance/badger/doc.go b/dsp/persistence/badger/doc.go similarity index 100% rename from dsp/persistance/badger/doc.go rename to dsp/persistence/badger/doc.go diff --git a/dsp/persistance/badger/locking.go b/dsp/persistence/badger/locking.go similarity index 94% rename from dsp/persistance/badger/locking.go rename to dsp/persistence/badger/locking.go index 5e4b1b9..0db2296 100644 --- a/dsp/persistance/badger/locking.go +++ b/dsp/persistence/badger/locking.go @@ -64,6 +64,11 @@ func (sp *StorageProvider) ReleaseLock(ctx context.Context, k lockKey) error { logger.Debug("Attempting to release lock") err := txn.Delete(k.key()) if err != nil { + if errors.Is(err, badger.ErrKeyNotFound) { + // No lock found is essentially released, this will most likely only happen on + // first time saves. + return nil + } logger.Error("Could not release lock", "err", err) } return err diff --git a/dsp/persistence/badger/logger_adaptor.go b/dsp/persistence/badger/logger_adaptor.go new file mode 100644 index 0000000..06dfae3 --- /dev/null +++ b/dsp/persistence/badger/logger_adaptor.go @@ -0,0 +1,42 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package badger + +import ( + "fmt" + "log/slog" +) + +// logAdaptor is a simple slog adaptor so that badger will use the same logger as the rest of the code. +// Sadly this will not use slog fields as badger uses a printf style of logging. +type logAdaptor struct { + Logger *slog.Logger +} + +func (la logAdaptor) Errorf(format string, v ...interface{}) { + la.Logger.Error(fmt.Sprintf(format, v...)) +} + +func (la logAdaptor) Warningf(format string, v ...interface{}) { + la.Logger.Warn(fmt.Sprintf(format, v...)) +} + +func (la logAdaptor) Infof(format string, v ...interface{}) { + la.Logger.Info(fmt.Sprintf(format, v...)) +} + +func (la logAdaptor) Debugf(format string, v ...interface{}) { + la.Logger.Debug(fmt.Sprintf(format, v...)) +} diff --git a/dsp/persistance/badger/provider.go b/dsp/persistence/badger/provider.go similarity index 83% rename from dsp/persistance/badger/provider.go rename to dsp/persistence/badger/provider.go index 9413df6..4b4376d 100644 --- a/dsp/persistance/badger/provider.go +++ b/dsp/persistence/badger/provider.go @@ -15,9 +15,7 @@ package badger import ( - "bytes" "context" - "encoding/gob" "fmt" "time" @@ -41,6 +39,8 @@ type storageKeyGenerator interface { type writeController interface { SetReadOnly() ReadOnly() bool + Modified() bool + ToBytes() ([]byte, error) storageKeyGenerator } @@ -56,6 +56,9 @@ func New(ctx context.Context, inMemory bool, dbPath string) (*StorageProvider, e opt = badger.DefaultOptions(dbPath) dbType = "disk" } + logger := logging.Extract(ctx) + opt.WithLogger(logAdaptor{logger}) + ctx, _ = logging.InjectLabels(ctx, "module", "badger", "db_type", dbType, @@ -95,27 +98,30 @@ func (sp StorageProvider) maintenance() { } // get is a generic function that gets the bytes from the database, decodes and returns it. -func get[T any](db *badger.DB, key []byte) (T, error) { - var thing T +func get(db *badger.DB, key []byte) ([]byte, error) { + var b []byte err := db.View(func(txn *badger.Txn) error { item, err := txn.Get(key) if err != nil { return err } return item.Value(func(val []byte) error { - dec := gob.NewDecoder(bytes.NewReader(val)) - return dec.Decode(thing) + b = append([]byte{}, val...) + return nil }) }) - return thing, err + if err != nil { + return nil, err + } + return b, err } // getLocked is a generic function that wraps get in a lock/unlock. -func getLocked[T writeController]( +func getLocked( ctx context.Context, sp *StorageProvider, key []byte, -) (T, error) { +) ([]byte, error) { logger := logging.Extract(ctx) logger.Info("Acquiring lock") if err := sp.AcquireLock(ctx, newLockKey(key)); err != nil { @@ -123,27 +129,20 @@ func getLocked[T writeController]( panic("Failed to acquire lock") } logger.Info("Lock acquired, fetching") - thing, err := get[T](sp.db, key) + b, err := get(sp.db, key) if err != nil { logger.Error("Couldn't fetch from db, unlocking", "err", err) if lockErr := sp.ReleaseLock(ctx, newLockKey(key)); lockErr != nil { logger.Error("Failed to unlock, will have to depend on TTL", "err", lockErr) } - var n T - return n, fmt.Errorf("failed to fetch from db") + return nil, fmt.Errorf("failed to fetch from db") } - return thing, nil + return b, nil } -func put[T any](db *badger.DB, key []byte, thing T) error { - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - err := enc.Encode(thing) - if err != nil { - return fmt.Errorf("could not encode in gob: %w", err) - } +func put(db *badger.DB, key []byte, value []byte) error { return db.Update(func(txn *badger.Txn) error { - return txn.Set(key, buf.Bytes()) + return txn.Set(key, value) }) } @@ -155,10 +154,15 @@ func putUnlock[T writeController](ctx context.Context, sp *StorageProvider, thin panic("Trying to write a read only entry") } key := thing.StorageKey() - if err := put(sp.db, key, thing); err != nil { - logger.Error("Could not save entry, not releasing lock", "err", err) - return err + if thing.Modified() { + b, err := thing.ToBytes() + if err != nil { + return err + } + if err := put(sp.db, key, b); err != nil { + logger.Error("Could not save entry, not releasing lock", "err", err) + return err + } } - return sp.ReleaseLock(ctx, newLockKey(key)) } diff --git a/dsp/persistance/badger/transfer_saver.go b/dsp/persistence/badger/transfer_saver.go similarity index 74% rename from dsp/persistance/badger/transfer_saver.go rename to dsp/persistence/badger/transfer_saver.go index f3e2ebb..b39044e 100644 --- a/dsp/persistance/badger/transfer_saver.go +++ b/dsp/persistence/badger/transfer_saver.go @@ -19,7 +19,8 @@ import ( "context" "fmt" - "github.com/go-dataspace/run-dsp/dsp/statemachine" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/transfer" "github.com/go-dataspace/run-dsp/logging" "github.com/google/uuid" ) @@ -29,17 +30,21 @@ import ( func (sp *StorageProvider) GetTransferR( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, -) (*statemachine.TransferRequest, error) { - key := statemachine.MkTransferKey(pid, role) + role constants.DataspaceRole, +) (*transfer.Request, error) { + key := transfer.GenerateKey(pid, role) logger := logging.Extract(ctx).With("pid", pid, "role", role, "key", string(key)) - transfer, err := get[*statemachine.TransferRequest](sp.db, key) + b, err := get(sp.db, key) if err != nil { logger.Error("Failed to get transfer", "err", err) return nil, fmt.Errorf("could not get transfer %w", err) } - transfer.SetReadOnly() - return transfer, nil + request, err := transfer.FromBytes(b) + if err != nil { + return nil, err + } + request.SetReadOnly() + return request, nil } // GetTransferRW gets a transfer but does NOT set the read-only property, allowing changes to be saved. @@ -48,22 +53,26 @@ func (sp *StorageProvider) GetTransferR( func (sp *StorageProvider) GetTransferRW( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, -) (*statemachine.TransferRequest, error) { - key := statemachine.MkTransferKey(pid, role) + role constants.DataspaceRole, +) (*transfer.Request, error) { + key := transfer.GenerateKey(pid, role) ctx, _ = logging.InjectLabels(ctx, "type", "transfer", "pid", pid, "role", role, "key", string(key)) - return getLocked[*statemachine.TransferRequest](ctx, sp, key) + b, err := getLocked(ctx, sp, key) + if err != nil { + return nil, err + } + return transfer.FromBytes(b) } // PutTransfer saves a transfer to the database. // If the transfer is set to read-only, it will panic as this is a bug in the code. // It will release the lock after it has saved. -func (sp *StorageProvider) PutTransfer(ctx context.Context, transfer *statemachine.TransferRequest) error { +func (sp *StorageProvider) PutTransfer(ctx context.Context, transfer *transfer.Request) error { ctx, _ = logging.InjectLabels( ctx, - "consumer_pid", transfer.ConsumerPID, - "provider_pid", transfer.ProviderPID, - "role", transfer.Role, + "consumer_pid", transfer.GetConsumerPID(), + "provider_pid", transfer.GetProviderPID(), + "role", transfer.GetRole(), ) return putUnlock(ctx, sp, transfer) } diff --git a/dsp/persistance/interface.go b/dsp/persistence/interface.go similarity index 83% rename from dsp/persistance/interface.go rename to dsp/persistence/interface.go index bfa9091..82853ee 100644 --- a/dsp/persistance/interface.go +++ b/dsp/persistence/interface.go @@ -14,12 +14,14 @@ // Package persistence contains the storage interfaces for the dataspace code. It also contains // constants and other shared code for the implementation packages. -package persistance +package persistence import ( "context" - "github.com/go-dataspace/run-dsp/dsp/statemachine" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/contract" + "github.com/go-dataspace/run-dsp/dsp/transfer" "github.com/go-dataspace/run-dsp/odrl" "github.com/google/uuid" ) @@ -42,18 +44,18 @@ type ContractSaver interface { GetContractR( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, - ) (*statemachine.Contract, error) + role constants.DataspaceRole, + ) (*contract.Negotiation, error) // GetContractRW gets a read/write version of a contract. This should set a contract specific // lock for the requested contract. GetContractRW( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, - ) (*statemachine.Contract, error) + role constants.DataspaceRole, + ) (*contract.Negotiation, error) // PutContract saves a contract, and releases the contract specific lock. If the contract // is read-only, it will return an error. - PutContract(ctx context.Context, contract *statemachine.Contract) error + PutContract(ctx context.Context, contract *contract.Negotiation) error } // AgreementSaver is an interface for storing/retrieving dataspace agreements. @@ -73,14 +75,14 @@ type TransferSaver interface { GetTransferR( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, - ) (*statemachine.TransferRequest, error) + role constants.DataspaceRole, + ) (*transfer.Request, error) // GetTransferRW gets a read/write version of a transfer request. GetTransferRW( ctx context.Context, pid uuid.UUID, - role statemachine.DataspaceRole, - ) (*statemachine.TransferRequest, error) + role constants.DataspaceRole, + ) (*transfer.Request, error) // PutTransfer saves a transfer. - PutTransfer(ctx context.Context, transfer *statemachine.TransferRequest) error + PutTransfer(ctx context.Context, transfer *transfer.Request) error } diff --git a/dsp/routing.go b/dsp/routing.go index f491da3..3c5ec2e 100644 --- a/dsp/routing.go +++ b/dsp/routing.go @@ -19,6 +19,7 @@ import ( "net/http" "net/url" + "github.com/go-dataspace/run-dsp/dsp/persistence" "github.com/go-dataspace/run-dsp/dsp/statemachine" providerv1 "github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" ) @@ -35,7 +36,7 @@ func GetWellKnownRoutes() http.Handler { func GetDSPRoutes( provider providerv1.ProviderServiceClient, - store statemachine.Archiver, + store persistence.StorageProvider, reconciler *statemachine.Reconciler, selfURL *url.URL, pingResponse *providerv1.PingResponse, diff --git a/dsp/statemachine/archiver.go b/dsp/statemachine/archiver.go deleted file mode 100644 index e6acd7b..0000000 --- a/dsp/statemachine/archiver.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2024 go-dataspace -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package statemachine - -import ( - "context" - "errors" - "fmt" - "sync" - - "github.com/go-dataspace/run-dsp/odrl" - "github.com/google/uuid" -) - -var ErrNotFound = errors.New("not found") - -type Archiver interface { - GetProviderContract(ctx context.Context, providerPID uuid.UUID) (*Contract, error) - PutProviderContract(ctx context.Context, contract *Contract) error - GetConsumerContract(ctx context.Context, consumerPID uuid.UUID) (*Contract, error) - PutConsumerContract(ctx context.Context, contract *Contract) error - - GetAgreement(ctx context.Context, agreementID uuid.UUID) (*odrl.Agreement, error) - DelAgreement(ctx context.Context, agreementID uuid.UUID) error - PutAgreement(ctx context.Context, agreement *odrl.Agreement) error - - GetProviderTransfer(ctx context.Context, providerPID uuid.UUID) (*TransferRequest, error) - PutProviderTransfer(ctx context.Context, contract *TransferRequest) error - GetConsumerTransfer(ctx context.Context, consumerPID uuid.UUID) (*TransferRequest, error) - PutConsumerTransfer(ctx context.Context, contract *TransferRequest) error -} - -type MemoryArchiver struct { - providerContracts map[uuid.UUID]*Contract - consumerContracts map[uuid.UUID]*Contract - agreements map[uuid.UUID]*odrl.Agreement - providerTransfers map[uuid.UUID]*TransferRequest - consumerTransfers map[uuid.UUID]*TransferRequest - sync.RWMutex -} - -func NewMemoryArchiver() *MemoryArchiver { - return &MemoryArchiver{ - providerContracts: make(map[uuid.UUID]*Contract), - consumerContracts: make(map[uuid.UUID]*Contract), - agreements: make(map[uuid.UUID]*odrl.Agreement), - providerTransfers: make(map[uuid.UUID]*TransferRequest), - consumerTransfers: make(map[uuid.UUID]*TransferRequest), - } -} - -func (ma *MemoryArchiver) PutProviderContract(ctx context.Context, contract *Contract) error { - return ma.putContract(contract.GetProviderPID(), contract, ma.providerContracts) -} - -func (ma *MemoryArchiver) PutConsumerContract(ctx context.Context, contract *Contract) error { - return ma.putContract(contract.GetConsumerPID(), contract, ma.consumerContracts) -} - -func (ma *MemoryArchiver) putContract(pid uuid.UUID, contract *Contract, contracts map[uuid.UUID]*Contract) error { - defer ma.Unlock() - ma.Lock() - contracts[pid] = contract - return nil -} - -func (ma *MemoryArchiver) GetProviderContract(ctx context.Context, pid uuid.UUID) (*Contract, error) { - return ma.getContract(pid, ma.providerContracts) -} - -func (ma *MemoryArchiver) GetConsumerContract(ctx context.Context, pid uuid.UUID) (*Contract, error) { - return ma.getContract(pid, ma.consumerContracts) -} - -func (ma *MemoryArchiver) getContract(pid uuid.UUID, contracts map[uuid.UUID]*Contract) (*Contract, error) { - defer ma.RUnlock() - ma.RLock() - c, ok := contracts[pid] - if !ok { - return nil, ErrNotFound - } - return c, nil -} - -func (ma *MemoryArchiver) PutProviderTransfer(ctx context.Context, transfer *TransferRequest) error { - return ma.putTransfer(transfer.GetProviderPID(), transfer, ma.providerTransfers) -} - -func (ma *MemoryArchiver) PutConsumerTransfer(ctx context.Context, transfer *TransferRequest) error { - return ma.putTransfer(transfer.GetConsumerPID(), transfer, ma.consumerTransfers) -} - -func (ma *MemoryArchiver) putTransfer( - pid uuid.UUID, transfer *TransferRequest, transfers map[uuid.UUID]*TransferRequest, -) error { - defer ma.Unlock() - ma.Lock() - transfers[pid] = transfer - return nil -} - -func (ma *MemoryArchiver) GetProviderTransfer(ctx context.Context, pid uuid.UUID) (*TransferRequest, error) { - return ma.getTransfer(pid, ma.providerTransfers) -} - -func (ma *MemoryArchiver) GetConsumerTransfer(ctx context.Context, pid uuid.UUID) (*TransferRequest, error) { - return ma.getTransfer(pid, ma.consumerTransfers) -} - -func (ma *MemoryArchiver) getTransfer( - pid uuid.UUID, - transfers map[uuid.UUID]*TransferRequest, -) (*TransferRequest, error) { - defer ma.RUnlock() - ma.RLock() - c, ok := transfers[pid] - if !ok { - return nil, ErrNotFound - } - return c, nil -} - -func (ma *MemoryArchiver) GetAgreement(ctx context.Context, agreementID uuid.UUID) (*odrl.Agreement, error) { - defer ma.RUnlock() - ma.RLock() - a, ok := ma.agreements[agreementID] - if !ok { - return nil, ErrNotFound - } - return a, nil -} - -func (ma *MemoryArchiver) DelAgreement(ctx context.Context, agreementID uuid.UUID) error { - delete(ma.agreements, agreementID) - return nil -} - -func (ma *MemoryArchiver) PutAgreement(ctx context.Context, agreement *odrl.Agreement) error { - u, err := uuid.Parse(agreement.ID) - if err != nil { - return fmt.Errorf("not a valid agreement ID: %w", err) - } - ma.agreements[u] = agreement - return nil -} diff --git a/dsp/statemachine/contract.go b/dsp/statemachine/contract.go deleted file mode 100644 index cbf5a85..0000000 --- a/dsp/statemachine/contract.go +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2024 go-dataspace -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package statemachine - -import ( - "fmt" - "net/url" - "slices" - "strconv" - - "github.com/go-dataspace/run-dsp/dsp/shared" - "github.com/go-dataspace/run-dsp/odrl" - "github.com/google/uuid" -) - -type DataspaceRole int8 - -const ( - DataspaceConsumer DataspaceRole = iota - DataspaceProvider -) - -var validTransitions = map[ContractState][]ContractState{ - ContractStates.INITIAL: { - ContractStates.OFFERED, - ContractStates.REQUESTED, - ContractStates.TERMINATED, - }, - ContractStates.REQUESTED: { - ContractStates.OFFERED, - ContractStates.AGREED, - ContractStates.TERMINATED, - }, - ContractStates.OFFERED: { - ContractStates.REQUESTED, - ContractStates.ACCEPTED, - ContractStates.TERMINATED, - }, - ContractStates.ACCEPTED: { - ContractStates.AGREED, - ContractStates.TERMINATED, - }, - ContractStates.AGREED: { - ContractStates.VERIFIED, - ContractStates.TERMINATED, - }, - ContractStates.VERIFIED: { - ContractStates.FINALIZED, - ContractStates.TERMINATED, - }, - ContractStates.FINALIZED: {}, - ContractStates.TERMINATED: {}, -} - -// Contract represents a contract negotiation. -type Contract struct { - ProviderPID uuid.UUID - ConsumerPID uuid.UUID - State ContractState - Offer odrl.Offer - Agreement odrl.Agreement - Callback *url.URL - Self *url.URL - Role DataspaceRole - - initial bool - ro bool -} - -func (cn *Contract) GetProviderPID() uuid.UUID { return cn.ProviderPID } -func (cn *Contract) SetProviderPID(u uuid.UUID) { cn.ProviderPID = u } -func (cn *Contract) GetConsumerPID() uuid.UUID { return cn.ConsumerPID } -func (cn *Contract) SetConsumerPID(u uuid.UUID) { cn.ProviderPID = u } -func (cn *Contract) GetState() ContractState { return cn.State } -func (cn *Contract) GetOffer() odrl.Offer { return cn.Offer } -func (cn *Contract) GetAgreement() odrl.Agreement { return cn.Agreement } -func (cn *Contract) GetRole() DataspaceRole { return cn.Role } -func (cn *Contract) GetCallback() *url.URL { return cn.Callback } -func (cn *Contract) GetSelf() *url.URL { return cn.Self } -func (cn *Contract) GetContract() *Contract { return cn } - -func (cn *Contract) SetReadOnly() { cn.ro = true } -func (cn *Contract) ReadOnly() bool { return cn.ro } - -func (cn *Contract) StorageKey() []byte { - id := cn.ConsumerPID - if cn.Role == DataspaceProvider { - id = cn.ProviderPID - } - return MkTransferKey(id, cn.Role) -} - -func (cn *Contract) SetState(state ContractState) error { - if !slices.Contains(validTransitions[cn.State], state) { - return fmt.Errorf("can't transition from %s to %s", cn.State, state) - } - cn.State = state - return nil -} - -// SetCallback sets the remote callback root. -func (cn *Contract) SetCallback(u string) error { - nu, err := url.Parse(u) - if err != nil { - return fmt.Errorf("invalid URL: %w", err) - } - cn.Callback = nu - return nil -} - -// GetContractNegotiation returns a ContractNegotion message. -func (cn *Contract) GetContractNegotiation() shared.ContractNegotiation { - return shared.ContractNegotiation{ - Context: dspaceContext, - Type: "dspace:ContractNegotiation", - ConsumerPID: cn.GetConsumerPID().URN(), - ProviderPID: cn.GetProviderPID().URN(), - State: cn.GetState().String(), - } -} - -// Copy does a deep copy of a contract, here mostly for a workaround that will go away once -// we implement a reconciliation loop. -func (cn *Contract) Copy() *Contract { - return &Contract{ - ProviderPID: cn.ProviderPID, - ConsumerPID: cn.ConsumerPID, - State: cn.State, - Offer: cn.Offer, - Agreement: cn.Agreement, - Callback: mustURL(cn.Callback), - Self: mustURL(cn.Self), - Role: cn.Role, - initial: cn.initial, - } -} - -func mustURL(u *url.URL) *url.URL { - n, err := url.Parse(u.String()) - if err != nil { - panic(err.Error()) - } - return n -} - -func MkContractKey(id uuid.UUID, role DataspaceRole) []byte { - return []byte("contract-" + id.String() + "-" + strconv.Itoa(int(role))) -} diff --git a/dsp/statemachine/contract_messages.go b/dsp/statemachine/contract_messages.go index 1365a8e..41df246 100644 --- a/dsp/statemachine/contract_messages.go +++ b/dsp/statemachine/contract_messages.go @@ -21,6 +21,8 @@ import ( "path" "time" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/contract" "github.com/go-dataspace/run-dsp/dsp/shared" "github.com/go-dataspace/run-dsp/logging" "github.com/go-dataspace/run-dsp/odrl" @@ -37,14 +39,14 @@ func cloneURL(u *url.URL) *url.URL { func makeContractRequestFunction( ctx context.Context, - c *Contract, + c *contract.Negotiation, cu *url.URL, reqBody []byte, - destinationState ContractState, + destinationState contract.State, reconciler *Reconciler, ) func() { var id uuid.UUID - if c.GetRole() == DataspaceConsumer { + if c.GetRole() == constants.DataspaceConsumer { id = c.GetConsumerPID() } else { id = c.GetProviderPID() @@ -66,7 +68,7 @@ func makeRequestFunction( cu *url.URL, reqBody []byte, id uuid.UUID, - role DataspaceRole, + role constants.DataspaceRole, destinationState string, recType ReconciliationType, reconciler *Reconciler, @@ -86,10 +88,10 @@ func makeRequestFunction( } //nolint:dupl -func sendContractRequest(ctx context.Context, r *Reconciler, c *Contract) (func(), error) { +func sendContractRequest(ctx context.Context, r *Reconciler, c *contract.Negotiation) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendContractRequest") contractRequest := shared.ContractRequestMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:ContractRequestMessage", ConsumerPID: c.GetConsumerPID().URN(), Offer: c.GetOffer().MessageOffer, @@ -119,16 +121,16 @@ func sendContractRequest(ctx context.Context, r *Reconciler, c *Contract) (func( c, cu, reqBody, - ContractStates.REQUESTED, + contract.States.REQUESTED, r, ), nil } //nolint:dupl -func sendContractOffer(ctx context.Context, r *Reconciler, c *Contract) (func(), error) { +func sendContractOffer(ctx context.Context, r *Reconciler, c *contract.Negotiation) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendContractOffer") contractOffer := shared.ContractOfferMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:ContractOfferMessage", ProviderPID: c.GetProviderPID().URN(), Offer: c.GetOffer().MessageOffer, @@ -160,26 +162,26 @@ func sendContractOffer(ctx context.Context, r *Reconciler, c *Contract) (func(), c, cu, reqBody, - ContractStates.OFFERED, + contract.States.OFFERED, r, ), nil } -func sendContractAgreement(ctx context.Context, r *Reconciler, c *Contract, a Archiver) (func(), error) { +func sendContractAgreement(ctx context.Context, r *Reconciler, c *contract.Negotiation) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendContractAgreement") - c.Agreement = odrl.Agreement{ + c.SetAgreement(&odrl.Agreement{ PolicyClass: odrl.PolicyClass{}, Type: "odrl:Agreement", ID: uuid.New().URN(), Target: c.GetOffer().Target, Timestamp: time.Now(), - } + }) contractAgreement := shared.ContractAgreementMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:ContractAgreementMessage", ProviderPID: c.GetProviderPID().URN(), ConsumerPID: c.GetConsumerPID().URN(), - Agreement: c.GetAgreement(), + Agreement: *c.GetAgreement(), CallbackAddress: c.GetSelf().String(), } @@ -188,10 +190,6 @@ func sendContractAgreement(ctx context.Context, r *Reconciler, c *Contract, a Ar logger.Error("Couldn't validate contract agreement", "err", err) return func() {}, fmt.Errorf("couldn't validate contract agreement: %w", err) } - if err := a.PutAgreement(ctx, &c.Agreement); err != nil { - logger.Error("Couldn't validate contract agreement", "err", err) - return func() {}, fmt.Errorf("couldn't validate contract agreement: %w", err) - } cu := cloneURL(c.GetCallback()) cu.Path = path.Join(cu.Path, "negotiations", c.GetConsumerPID().String(), "agreement") @@ -200,17 +198,17 @@ func sendContractAgreement(ctx context.Context, r *Reconciler, c *Contract, a Ar c, cu, reqBody, - ContractStates.AGREED, + contract.States.AGREED, r, ), nil } func sendContractEvent( - ctx context.Context, r *Reconciler, c *Contract, pid uuid.UUID, state ContractState, + ctx context.Context, r *Reconciler, c *contract.Negotiation, pid uuid.UUID, state contract.State, ) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendContractEvent") contractEvent := shared.ContractNegotiationEventMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:ContractNegotiationEventMessage", ProviderPID: c.GetProviderPID().URN(), ConsumerPID: c.GetConsumerPID().URN(), @@ -234,10 +232,10 @@ func sendContractEvent( ), nil } -func sendContractVerification(ctx context.Context, r *Reconciler, c *Contract) (func(), error) { +func sendContractVerification(ctx context.Context, r *Reconciler, c *contract.Negotiation) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendContractVerification") contractVerification := shared.ContractAgreementVerificationMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:ContractAgreementVerificationMessage", ProviderPID: c.GetProviderPID().URN(), ConsumerPID: c.GetConsumerPID().URN(), @@ -257,7 +255,7 @@ func sendContractVerification(ctx context.Context, r *Reconciler, c *Contract) ( c, cu, reqBody, - ContractStates.VERIFIED, + contract.States.VERIFIED, r, ), nil } diff --git a/dsp/statemachine/contract_statemachine_test.go b/dsp/statemachine/contract_statemachine_test.go index 49993bf..1a4e7d4 100644 --- a/dsp/statemachine/contract_statemachine_test.go +++ b/dsp/statemachine/contract_statemachine_test.go @@ -16,20 +16,19 @@ package statemachine_test import ( "context" - "encoding/json" "net/url" "testing" - "time" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/contract" + "github.com/go-dataspace/run-dsp/dsp/persistence/badger" "github.com/go-dataspace/run-dsp/dsp/shared" "github.com/go-dataspace/run-dsp/dsp/statemachine" "github.com/go-dataspace/run-dsp/logging" mockprovider "github.com/go-dataspace/run-dsp/mocks/github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" "github.com/go-dataspace/run-dsp/odrl" - providerv1 "github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" "github.com/google/uuid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) type MockRequester struct { @@ -56,526 +55,14 @@ func (mr *MockRequester) SendHTTPRequest( return mr.Response, nil } -func decode[T any](d []byte) (T, error) { - var m T - err := json.Unmarshal(d, &m) - return m, err -} - -const ( - reconcileWait = 1 * time.Second -) - var ( target = uuid.MustParse("68d3d534-06b9-4700-9890-915bc32ecb75") consumerPID = uuid.MustParse("d6bc4c28-973b-4c2f-b63f-08076c4fc65e") providerPID = uuid.MustParse("76e705bb-cd5a-49f3-99c2-cec1406c8e9e") providerCallback = urlMustParse("https://provider.dsp/") consumerCallback = urlMustParse("https://consumer.dsp/callback/") - provInitCB = urlMustParse("https://consumer.dsp/") - - publishURL = "https://example.org/publish-here.pdf" - token = "some-test-token" ) -// TestStateMachinesConsumerInitPull tests a whole statemachine run, this will do the happy path, -// acting like the consumer initiated it. Once the contract state machine has successfully completed, -// it will do a pull transfer request. -// TODO: This is very unreadable, clean it up. -// -//nolint:funlen,maintidx -func TestStateMachineConsumerInitConsumerPull(t *testing.T) { - t.Parallel() - - offer := odrl.Offer{ - MessageOffer: odrl.MessageOffer{ - PolicyClass: odrl.PolicyClass{ - AbstractPolicyRule: odrl.AbstractPolicyRule{}, - ID: uuid.New().URN(), - }, - Type: "odrl:Offer", - Target: target.URN(), - }, - } - - logger := logging.NewJSON("error", true) - ctx := logging.Inject(context.Background(), logger) - ctx, done := context.WithCancel(ctx) - defer done() - - store := statemachine.NewMemoryArchiver() - requester := &MockRequester{} - - mockProvider := mockprovider.NewMockProviderServiceClient(t) - mockProvider.On("GetDataset", mock.Anything, &providerv1.GetDatasetRequest{ - DatasetId: target.String(), - }).Return(&providerv1.GetDatasetResponse{ - Dataset: &providerv1.Dataset{}, - }, nil) - - reconciler := statemachine.NewReconciler(ctx, requester, store) - reconciler.Run() - - ctx, consumerInit, err := statemachine.NewContract( - ctx, store, mockProvider, reconciler, uuid.UUID{}, consumerPID, - statemachine.ContractStates.INITIAL, offer, providerCallback, consumerCallback, statemachine.DataspaceConsumer) - assert.Nil(t, err) - assert.Nil(t, consumerInit.GetArchiver().PutConsumerContract(ctx, consumerInit.GetContract())) - apply, err := consumerInit.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - consumerInitRes, err := store.GetConsumerContract(ctx, consumerPID) - validateContract(t, err, consumerInitRes, statemachine.ContractStates.REQUESTED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, providerCallback.String()+"negotiations/request", requester.ReceivedURL.String()) - - reqMSG, err := decode[shared.ContractRequestMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, consumerPID.URN(), reqMSG.ConsumerPID) - assert.Equal(t, consumerCallback.String(), reqMSG.CallbackAddress) - assert.Equal(t, target.URN(), reqMSG.Offer.Target) - - ctx, providerInit, err := statemachine.NewContract( - ctx, store, mockProvider, reconciler, uuid.UUID{}, uuid.MustParse(reqMSG.ConsumerPID), - statemachine.ContractStates.INITIAL, odrl.Offer{MessageOffer: reqMSG.Offer}, - urlMustParse(reqMSG.CallbackAddress), providerCallback, statemachine.DataspaceProvider, - ) - assert.Nil(t, err) - - ctx, nextProvider, err := providerInit.Recv(ctx, reqMSG) - validateContract(t, err, nextProvider.GetContract(), statemachine.ContractStates.REQUESTED, false) - - gProviderPID := nextProvider.GetProviderPID() - - apply, err = nextProvider.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - providerReqRes, err := store.GetProviderContract(ctx, gProviderPID) - validateContract(t, err, providerReqRes, statemachine.ContractStates.OFFERED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - consumerCallback.String()+"negotiations/"+consumerPID.String()+"/offers", - requester.ReceivedURL.String()) - - offerMSG, err := decode[shared.ContractOfferMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, consumerPID.URN(), offerMSG.ConsumerPID) - assert.Equal(t, gProviderPID.URN(), offerMSG.ProviderPID) - assert.Equal(t, providerCallback.String(), offerMSG.CallbackAddress) - assert.Equal(t, target.URN(), offerMSG.Offer.Target) - - consContract, err := store.GetConsumerContract(ctx, uuid.MustParse(offerMSG.ConsumerPID)) - validateContract(t, err, consContract, statemachine.ContractStates.REQUESTED, false) - - ctx, cState := statemachine.GetContractNegotiation(ctx, store, consContract, mockProvider, reconciler) - ctx, cNext, err := cState.Recv(ctx, offerMSG) - assert.Nil(t, err) - assert.Equal(t, statemachine.ContractStates.OFFERED, cNext.GetState()) - assert.Equal(t, providerCallback, cNext.GetCallback()) - assert.Equal(t, target.URN(), cNext.GetOffer().MessageOffer.Target) - assert.Equal(t, consumerCallback, cNext.GetSelf()) - - apply, err = cNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - cAcceptRes, err := store.GetConsumerContract(ctx, consumerPID) - validateContract(t, err, cAcceptRes, statemachine.ContractStates.ACCEPTED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - providerCallback.String()+"negotiations/"+gProviderPID.String()+"/events", - requester.ReceivedURL.String()) - - acceptMessage, err := decode[shared.ContractNegotiationEventMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, consumerPID.URN(), acceptMessage.ConsumerPID) - assert.Equal(t, gProviderPID.URN(), acceptMessage.ProviderPID) - assert.Equal(t, "dspace:ACCEPTED", acceptMessage.EventType) - - provContract, err := store.GetProviderContract(ctx, uuid.MustParse(acceptMessage.ProviderPID)) - validateContract(t, err, provContract, statemachine.ContractStates.OFFERED, false) - - ctx, pState := statemachine.GetContractNegotiation(ctx, store, provContract, mockProvider, reconciler) - ctx, pNext, err := pState.Recv(ctx, acceptMessage) - validateContract(t, err, pNext.GetContract(), statemachine.ContractStates.ACCEPTED, false) - - apply, err = pNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - pAgreeRes, err := store.GetProviderContract(ctx, gProviderPID) - validateContract(t, err, pAgreeRes, statemachine.ContractStates.AGREED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - consumerCallback.String()+"negotiations/"+consumerPID.String()+"/agreement", - requester.ReceivedURL.String()) - - agreeMSG, err := decode[shared.ContractAgreementMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, consumerPID.URN(), agreeMSG.ConsumerPID) - assert.Equal(t, gProviderPID.URN(), agreeMSG.ProviderPID) - assert.Equal(t, providerCallback.String(), agreeMSG.CallbackAddress) - assert.Equal(t, target.URN(), agreeMSG.Agreement.Target) - - gAgreement := agreeMSG.Agreement - - cContract, err := store.GetConsumerContract(ctx, uuid.MustParse(acceptMessage.ConsumerPID)) - validateContract(t, err, cContract, statemachine.ContractStates.ACCEPTED, false) - - ctx, cState = statemachine.GetContractNegotiation(ctx, store, cContract, mockProvider, reconciler) - ctx, cNext, err = cState.Recv(ctx, agreeMSG) - validateContract(t, err, cNext.GetContract(), statemachine.ContractStates.AGREED, false) - - apply, err = cNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - cVerRes, err := store.GetConsumerContract(ctx, consumerPID) - validateContract(t, err, cVerRes, statemachine.ContractStates.VERIFIED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - providerCallback.String()+"negotiations/"+gProviderPID.String()+"/agreement/verification", - requester.ReceivedURL.String()) - - verMSG, err := decode[shared.ContractAgreementVerificationMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, consumerPID.URN(), verMSG.ConsumerPID) - assert.Equal(t, gProviderPID.URN(), verMSG.ProviderPID) - - pContract, err := store.GetProviderContract(ctx, uuid.MustParse(acceptMessage.ProviderPID)) - validateContract(t, err, pContract, statemachine.ContractStates.AGREED, false) - - ctx, pState = statemachine.GetContractNegotiation(ctx, store, pContract, mockProvider, reconciler) - ctx, pNext, err = pState.Recv(ctx, verMSG) - validateContract(t, err, pNext.GetContract(), statemachine.ContractStates.VERIFIED, false) - - apply, err = pNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - pFinRes, err := store.GetProviderContract(ctx, gProviderPID) - validateContract(t, err, pFinRes, statemachine.ContractStates.FINALIZED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - consumerCallback.String()+"negotiations/"+consumerPID.String()+"/events", - requester.ReceivedURL.String()) - - finMSG, err := decode[shared.ContractNegotiationEventMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, consumerPID.URN(), finMSG.ConsumerPID) - assert.Equal(t, gProviderPID.URN(), finMSG.ProviderPID) - assert.Equal(t, "dspace:FINALIZED", finMSG.EventType) - - cContract, err = store.GetConsumerContract(ctx, uuid.MustParse(finMSG.ConsumerPID)) - validateContract(t, err, cContract, statemachine.ContractStates.VERIFIED, false) - - ctx, cState = statemachine.GetContractNegotiation(ctx, store, cContract, mockProvider, reconciler) - ctx, cNext, err = cState.Recv(ctx, finMSG) - validateContract(t, err, cNext.GetContract(), statemachine.ContractStates.FINALIZED, false) - - agreementID := uuid.MustParse(gAgreement.ID) - - trCPID := uuid.New() - cTransInit, err := statemachine.NewTransferRequest( - ctx, store, mockProvider, reconciler, - trCPID, agreementID, "HTTP_PULL", - providerCallback, consumerCallback, statemachine.DataspaceConsumer, - statemachine.TransferRequestStates.TRANSFERINITIAL, nil, - ) - assert.Nil(t, err) - apply, err = cTransInit.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - cTransInitRes, err := store.GetConsumerTransfer(ctx, trCPID) - validateTransfer(t, err, cTransInitRes, statemachine.TransferRequestStates.TRANSFERREQUESTED, agreementID) - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, providerCallback.String()+"transfers/request", requester.ReceivedURL.String()) - - trReqMSG, err := decode[shared.TransferRequestMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, trCPID.URN(), trReqMSG.ConsumerPID) - assert.Equal(t, agreementID.URN(), trReqMSG.AgreementID) - assert.Equal(t, "HTTP_PULL", trReqMSG.Format) - assert.Equal(t, consumerCallback.String(), trReqMSG.CallbackAddress) - - pTransInit, err := statemachine.NewTransferRequest( - ctx, store, mockProvider, reconciler, - uuid.MustParse(trReqMSG.ConsumerPID), uuid.MustParse(trReqMSG.AgreementID), trReqMSG.Format, - urlMustParse(trReqMSG.CallbackAddress), providerCallback, statemachine.DataspaceProvider, - statemachine.TransferRequestStates.TRANSFERINITIAL, nil, - ) - assert.Nil(t, err) - pTransNext, err := pTransInit.Recv(ctx, trReqMSG) - validateTransfer( - t, err, pTransNext.GetTransferRequest(), statemachine.TransferRequestStates.TRANSFERREQUESTED, agreementID) - - trProviderPID := pTransNext.GetProviderPID() - mockProvider.On("PublishDataset", mock.Anything, &providerv1.PublishDatasetRequest{ - DatasetId: target.String(), - PublishId: trProviderPID.String(), - }).Return(&providerv1.PublishDatasetResponse{ - PublishInfo: &providerv1.PublishInfo{ - Url: publishURL, - AuthenticationType: providerv1.AuthenticationType_AUTHENTICATION_TYPE_BEARER, - Username: "", - Password: token, - }, - }, nil) - - apply, err = pTransNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - pTransRes, err := store.GetProviderTransfer(ctx, trProviderPID) - validateTransfer(t, err, pTransRes, statemachine.TransferRequestStates.STARTED, agreementID) - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, consumerCallback.String()+"transfers/"+trCPID.String()+"/start", requester.ReceivedURL.String()) - - trStartMSG, err := decode[shared.TransferStartMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, trProviderPID.URN(), trStartMSG.ProviderPID) - assert.Equal(t, trCPID.URN(), trStartMSG.ConsumerPID) - assert.Equal(t, publishURL, trStartMSG.DataAddress.Endpoint) - assert.Contains(t, trStartMSG.DataAddress.EndpointProperties, shared.EndpointProperty{ - Type: "dspace:EndpointProperty", - Name: "authorization", - Value: token, - }) - assert.Contains(t, trStartMSG.DataAddress.EndpointProperties, shared.EndpointProperty{ - Type: "dspace:EndpointProperty", - Name: "authType", - Value: "bearer", - }) - - cTraReq, err := store.GetConsumerTransfer(ctx, uuid.MustParse(trStartMSG.ConsumerPID)) - validateTransfer(t, err, cTraReq, statemachine.TransferRequestStates.TRANSFERREQUESTED, agreementID) - - cTransInit = statemachine.GetTransferRequestNegotiation(store, cTraReq, mockProvider, reconciler) - cTransNext, err := cTransInit.Recv(ctx, trStartMSG) - validateTransfer(t, err, cTransNext.GetTransferRequest(), statemachine.TransferRequestStates.STARTED, agreementID) - - apply, err = cTransNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - cTransRes, err := store.GetConsumerTransfer(ctx, trCPID) - validateTransfer(t, err, cTransRes, statemachine.TransferRequestStates.COMPLETED, agreementID) - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal( - t, - providerCallback.String()+"transfers/"+trProviderPID.String()+"/completion", - requester.ReceivedURL.String(), - ) - - trCompletionMSG, err := decode[shared.TransferCompletionMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, trCPID.URN(), trCompletionMSG.ConsumerPID) - assert.Equal(t, trProviderPID.URN(), trCompletionMSG.ProviderPID) - - mockProvider.On("UnpublishDataset", mock.Anything, &providerv1.UnpublishDatasetRequest{ - PublishId: trProviderPID.String(), - }).Return(&providerv1.UnpublishDatasetResponse{ - Success: true, - }, nil) - - pTransContractStarted, err := store.GetProviderTransfer(ctx, trProviderPID) - validateTransfer(t, err, pTransContractStarted, statemachine.TransferRequestStates.STARTED, agreementID) - - pTransStarted := statemachine.GetTransferRequestNegotiation(store, pTransContractStarted, mockProvider, reconciler) - pTransNext, err = pTransStarted.Recv(ctx, trCompletionMSG) - validateTransfer(t, err, pTransNext.GetTransferRequest(), statemachine.TransferRequestStates.COMPLETED, agreementID) -} - -// TestContractStateMachineConsumerInit tests a whole contract statemachine run, this will do the happy path, -// acting like the consumer initiated it. -// -//nolint:funlen -func TestContractStateMachineProviderInit(t *testing.T) { - t.Parallel() - - offer := odrl.Offer{ - MessageOffer: odrl.MessageOffer{ - PolicyClass: odrl.PolicyClass{ - AbstractPolicyRule: odrl.AbstractPolicyRule{}, - ID: uuid.New().URN(), - }, - Type: "odrl:Offer", - Target: target.URN(), - }, - } - - logger := logging.NewJSON("error", true) - ctx := logging.Inject(context.Background(), logger) - store := statemachine.NewMemoryArchiver() - requester := &MockRequester{} - - mockProvider := mockprovider.NewMockProviderServiceClient(t) - mockProvider.On("GetDataset", mock.Anything, &providerv1.GetDatasetRequest{ - DatasetId: target.String(), - }).Return(&providerv1.GetDatasetResponse{ - Dataset: &providerv1.Dataset{}, - }, nil) - - reconciler := statemachine.NewReconciler(ctx, requester, store) - reconciler.Run() - - ctx, pInit, err := statemachine.NewContract( - ctx, store, mockProvider, reconciler, providerPID, uuid.UUID{}, - statemachine.ContractStates.INITIAL, offer, provInitCB, providerCallback, statemachine.DataspaceProvider) - assert.Nil(t, err) - apply, err := pInit.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - pInitC, err := store.GetProviderContract(ctx, providerPID) - validateContract(t, err, pInitC, statemachine.ContractStates.OFFERED, true) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, provInitCB.String()+"negotiations/offers", requester.ReceivedURL.String()) - - offMSG, err := decode[shared.ContractOfferMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, providerPID.URN(), offMSG.ProviderPID) - assert.Equal(t, providerCallback.String(), offMSG.CallbackAddress) - assert.Equal(t, target.URN(), offMSG.Offer.Target) - - ctx, cInit, err := statemachine.NewContract( - ctx, store, mockProvider, reconciler, uuid.MustParse(offMSG.ProviderPID), uuid.UUID{}, - statemachine.ContractStates.INITIAL, odrl.Offer{MessageOffer: offMSG.Offer}, - urlMustParse(offMSG.CallbackAddress), consumerCallback, statemachine.DataspaceConsumer, - ) - assert.Nil(t, err) - - ctx, nextProvider, err := cInit.Recv(ctx, offMSG) - validateContract(t, err, nextProvider.GetContract(), statemachine.ContractStates.OFFERED, false) - - gConsumerPID := nextProvider.GetConsumerPID() - - apply, err = nextProvider.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - cReqRes, err := store.GetConsumerContract(ctx, gConsumerPID) - validateContract(t, err, cReqRes, statemachine.ContractStates.REQUESTED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - providerCallback.String()+"negotiations/"+providerPID.String()+"/request", - requester.ReceivedURL.String()) - - reqMSG, err := decode[shared.ContractRequestMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, gConsumerPID.URN(), reqMSG.ConsumerPID) - assert.Equal(t, providerPID.URN(), reqMSG.ProviderPID) - assert.Equal(t, consumerCallback.String(), reqMSG.CallbackAddress) - assert.Equal(t, target.URN(), reqMSG.Offer.Target) - - pReqContract, err := store.GetProviderContract(ctx, uuid.MustParse(reqMSG.ProviderPID)) - validateContract(t, err, pReqContract, statemachine.ContractStates.OFFERED, true) - - ctx, pOffState := statemachine.GetContractNegotiation(ctx, store, pReqContract, mockProvider, reconciler) - ctx, pOffNext, err := pOffState.Recv(ctx, reqMSG) - validateContract(t, err, pOffNext.GetContract(), statemachine.ContractStates.REQUESTED, false) - - apply, err = pOffNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - pAgreedRes, err := store.GetProviderContract(ctx, providerPID) - validateContract(t, err, pAgreedRes, statemachine.ContractStates.AGREED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - consumerCallback.String()+"negotiations/"+gConsumerPID.String()+"/agreement", - requester.ReceivedURL.String()) - - agreeMSG, err := decode[shared.ContractAgreementMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, gConsumerPID.URN(), agreeMSG.ConsumerPID) - assert.Equal(t, providerPID.URN(), agreeMSG.ProviderPID) - assert.Equal(t, providerCallback.String(), agreeMSG.CallbackAddress) - assert.Equal(t, target.URN(), agreeMSG.Agreement.Target) - - cContract, err := store.GetConsumerContract(ctx, uuid.MustParse(agreeMSG.ConsumerPID)) - validateContract(t, err, cContract, statemachine.ContractStates.REQUESTED, false) - - ctx, pOffState = statemachine.GetContractNegotiation(ctx, store, cContract, mockProvider, reconciler) - ctx, pOffNext, err = pOffState.Recv(ctx, agreeMSG) - validateContract(t, err, pOffNext.GetContract(), statemachine.ContractStates.AGREED, false) - - apply, err = pOffNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - cVerRes, err := store.GetConsumerContract(ctx, gConsumerPID) - validateContract(t, err, cVerRes, statemachine.ContractStates.VERIFIED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - providerCallback.String()+"negotiations/"+providerPID.String()+"/agreement/verification", - requester.ReceivedURL.String()) - - verMSG, err := decode[shared.ContractAgreementVerificationMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, gConsumerPID.URN(), verMSG.ConsumerPID) - assert.Equal(t, providerPID.URN(), verMSG.ProviderPID) - - pContract, err := store.GetProviderContract(ctx, uuid.MustParse(verMSG.ProviderPID)) - validateContract(t, err, pContract, statemachine.ContractStates.AGREED, false) - - ctx, pState := statemachine.GetContractNegotiation(ctx, store, pContract, mockProvider, reconciler) - ctx, pNext, err := pState.Recv(ctx, verMSG) - validateContract(t, err, pNext.GetContract(), statemachine.ContractStates.VERIFIED, false) - - apply, err = pNext.Send(ctx) - assert.Nil(t, err) - apply() - time.Sleep(reconcileWait) - - pFinRes, err := store.GetProviderContract(ctx, providerPID) - validateContract(t, err, pFinRes, statemachine.ContractStates.FINALIZED, false) - - assert.Equal(t, "POST", requester.ReceivedMethod) - assert.Equal(t, - consumerCallback.String()+"negotiations/"+gConsumerPID.String()+"/events", - requester.ReceivedURL.String()) - - finMSG, err := decode[shared.ContractNegotiationEventMessage](requester.ReceivedBody) - assert.Nil(t, err) - assert.Equal(t, gConsumerPID.URN(), finMSG.ConsumerPID) - assert.Equal(t, providerPID.URN(), finMSG.ProviderPID) - assert.Equal(t, "dspace:FINALIZED", finMSG.EventType) - - cContract, err = store.GetConsumerContract(ctx, uuid.MustParse(finMSG.ConsumerPID)) - validateContract(t, err, cContract, statemachine.ContractStates.VERIFIED, false) - - ctx, pOffState = statemachine.GetContractNegotiation(ctx, store, cContract, mockProvider, reconciler) - _, pOffNext, err = pOffState.Recv(ctx, finMSG) - validateContract(t, err, pOffNext.GetContract(), statemachine.ContractStates.FINALIZED, false) -} - //nolint:funlen func TestTermination(t *testing.T) { t.Parallel() @@ -596,7 +83,9 @@ func TestTermination(t *testing.T) { ctx, done := context.WithCancel(ctx) defer done() - store := statemachine.NewMemoryArchiver() + store, err := badger.New(ctx, true, "") + assert.Nil(t, err) + requester := &MockRequester{} mockProvider := mockprovider.NewMockProviderServiceClient(t) @@ -604,23 +93,23 @@ func TestTermination(t *testing.T) { reconciler := statemachine.NewReconciler(ctx, requester, store) reconciler.Run() - for _, role := range []statemachine.DataspaceRole{ - statemachine.DataspaceConsumer, - statemachine.DataspaceProvider, + for _, role := range []constants.DataspaceRole{ + constants.DataspaceConsumer, + constants.DataspaceProvider, } { - for _, state := range []statemachine.ContractState{ - statemachine.ContractStates.REQUESTED, - statemachine.ContractStates.OFFERED, - statemachine.ContractStates.ACCEPTED, - statemachine.ContractStates.AGREED, - statemachine.ContractStates.VERIFIED, + for _, state := range []contract.State{ + contract.States.REQUESTED, + contract.States.OFFERED, + contract.States.ACCEPTED, + contract.States.AGREED, + contract.States.VERIFIED, } { consumerPID := uuid.New() providerPID := uuid.New() - ctx, consumerInit, err := statemachine.NewContract( - ctx, store, mockProvider, reconciler, providerPID, consumerPID, + negotiation := contract.New( + providerPID, consumerPID, state, offer, providerCallback, consumerCallback, role) - assert.Nil(t, err) + ctx, consumerInit := statemachine.GetContractNegotiation(ctx, negotiation, mockProvider, reconciler) msg := shared.ContractNegotiationTerminationMessage{ Context: shared.GetDSPContext(), Type: "dspace:ContractNegotiationTerminationMessage", @@ -639,61 +128,7 @@ func TestTermination(t *testing.T) { assert.Nil(t, err) _, err = next.Send(ctx) assert.Nil(t, err) - var contract *statemachine.Contract - switch role { - case statemachine.DataspaceProvider: - contract, err = store.GetProviderContract(ctx, providerPID) - case statemachine.DataspaceConsumer: - contract, err = store.GetConsumerContract(ctx, consumerPID) - } - assert.Nil(t, err) - assert.Equal(t, statemachine.ContractStates.TERMINATED, contract.GetState()) + assert.Equal(t, contract.States.TERMINATED, next.GetContract().GetState()) } } } - -func validateContract( - t *testing.T, err error, c *statemachine.Contract, state statemachine.ContractState, provInit bool, -) { - t.Helper() - assert.Nil(t, err) - assert.Equal(t, state, c.GetState()) - if c.GetRole() == statemachine.DataspaceConsumer { - assert.Equal(t, providerCallback, c.GetCallback()) - assert.Equal(t, consumerCallback, c.GetSelf()) - } else { - if provInit { - assert.Equal(t, provInitCB, c.GetCallback()) - } else { - assert.Equal(t, consumerCallback, c.GetCallback()) - assert.Equal(t, providerCallback, c.GetSelf()) - } - } - assert.Equal(t, target.URN(), c.GetOffer().MessageOffer.Target) -} - -func validateTransfer( - t *testing.T, err error, c *statemachine.TransferRequest, state statemachine.TransferRequestState, - agreementID uuid.UUID, -) { - t.Helper() - assert.Nil(t, err) - assert.Equal(t, state, c.GetState()) - assert.Equal(t, agreementID, c.GetAgreementID()) - if c.GetRole() == statemachine.DataspaceConsumer { - assert.Equal(t, providerCallback, c.GetCallback()) - assert.Equal(t, consumerCallback, c.GetSelf()) - } else { - assert.Equal(t, consumerCallback, c.GetCallback()) - assert.Equal(t, providerCallback, c.GetSelf()) - } - - assert.Equal(t, statemachine.DirectionPull, c.GetTransferDirection()) - - if c.GetPublishInfo() != nil { - assert.Equal(t, publishURL, c.GetPublishInfo().Url) - assert.Equal(t, providerv1.AuthenticationType_AUTHENTICATION_TYPE_BEARER, c.GetPublishInfo().AuthenticationType) - assert.Equal(t, "", c.GetPublishInfo().Username) - assert.Equal(t, token, c.GetPublishInfo().Password) - } -} diff --git a/dsp/statemachine/contract_transitions.go b/dsp/statemachine/contract_transitions.go index 01d18df..6985e4e 100644 --- a/dsp/statemachine/contract_transitions.go +++ b/dsp/statemachine/contract_transitions.go @@ -16,13 +16,13 @@ package statemachine import ( "context" + "errors" "fmt" "net/url" "strings" + "github.com/go-dataspace/run-dsp/dsp/contract" "github.com/go-dataspace/run-dsp/dsp/shared" - "github.com/go-dataspace/run-dsp/internal/constants" - "github.com/go-dataspace/run-dsp/jsonld" "github.com/go-dataspace/run-dsp/logging" "github.com/go-dataspace/run-dsp/odrl" providerv1 "github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" @@ -30,19 +30,19 @@ import ( ) var ( - emptyUUID = uuid.UUID{} - dspaceContext = jsonld.NewRootContext([]jsonld.ContextEntry{{ID: constants.DSPContext}}) + emptyUUID = uuid.UUID{} + ErrNotFound = errors.New("not found") ) type Contracter interface { GetProviderPID() uuid.UUID GetConsumerPID() uuid.UUID - GetState() ContractState + GetState() contract.State GetCallback() *url.URL SetCallback(u string) error GetSelf() *url.URL - SetState(state ContractState) error - GetContract() *Contract + SetState(state contract.State) error + GetContract() *contract.Negotiation GetOffer() odrl.Offer GetContractNegotiation() shared.ContractNegotiation } @@ -53,25 +53,22 @@ type ContractNegotiationState interface { Contracter Recv(ctx context.Context, message any) (context.Context, ContractNegotiationState, error) Send(ctx context.Context) (func(), error) - GetArchiver() Archiver GetProvider() providerv1.ProviderServiceClient GetReconciler() *Reconciler } type stateMachineDeps struct { - a Archiver p providerv1.ProviderServiceClient r *Reconciler } -func (cd *stateMachineDeps) GetArchiver() Archiver { return cd.a } func (cd *stateMachineDeps) GetProvider() providerv1.ProviderServiceClient { return cd.p } func (cd *stateMachineDeps) GetReconciler() *Reconciler { return cd.r } // ContractNegotiationInitial is an initial state for a contract that hasn't been actually // been submitted yet. type ContractNegotiationInitial struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -104,17 +101,14 @@ func (cn *ContractNegotiationInitial) Recv( logger.Error("target dataset not found", "err", err) return ctx, nil, fmt.Errorf("dataset %s: %w", cn.GetOffer().Target, ErrNotFound) } - if err := cn.SetState(ContractStates.REQUESTED); err != nil { + if err := cn.SetState(contract.States.REQUESTED); err != nil { logger.Error("could not transition state", "err", err) return ctx, nil, fmt.Errorf("could not set state: %w", err) } - cn.Contract.ProviderPID = uuid.New() - cn.Contract.initial = true - if err := cn.a.PutProviderContract(ctx, cn.GetContract()); err != nil { - logger.Error("failed to save contract", "err", err) - return ctx, nil, fmt.Errorf("failed to save contract: %w", err) - } - ctx, cns := GetContractNegotiation(ctx, cn.a, cn.GetContract(), cn.GetProvider(), cn.GetReconciler()) + cn.Negotiation.SetProviderPID(uuid.New()) + cn.Negotiation.SetInitial() + + ctx, cns := GetContractNegotiation(ctx, cn.GetContract(), cn.GetProvider(), cn.GetReconciler()) return ctx, cns, nil case shared.ContractOfferMessage: ctx, logger = logging.InjectLabels(ctx, @@ -122,17 +116,13 @@ func (cn *ContractNegotiationInitial) Recv( "dataset_target", cn.GetOffer().Target, ) // This is the initial offer, we can assuem all data is freshly made based on the offer. - if err := cn.SetState(ContractStates.OFFERED); err != nil { + if err := cn.SetState(contract.States.OFFERED); err != nil { logger.Error("could not transition state", "err", err) return ctx, nil, fmt.Errorf("could not set state: %w", err) } - cn.Contract.ConsumerPID = uuid.New() - cn.Contract.initial = true - if err := cn.a.PutConsumerContract(ctx, cn.GetContract()); err != nil { - logger.Error("failed to save contract", "err", err) - return ctx, nil, fmt.Errorf("failed to save contract: %w", err) - } - ctx, cns := GetContractNegotiation(ctx, cn.a, cn.GetContract(), cn.GetProvider(), cn.GetReconciler()) + cn.Negotiation.SetConsumerPID(uuid.New()) + cn.Negotiation.SetInitial() + ctx, cns := GetContractNegotiation(ctx, cn.GetContract(), cn.GetProvider(), cn.GetReconciler()) return ctx, cns, nil default: return ctx, nil, fmt.Errorf("Message type %s is not supported at this stage", t) @@ -143,7 +133,7 @@ func (cn *ContractNegotiationInitial) Recv( // This needs either the contract's consumer or provider PID set, but not both. // If the provider PID is set, it will send out a contract offer to the callback. // If the consumer PID is set, it will send out a contract request to the callback. -func (cn *ContractNegotiationInitial) Send(ctx context.Context) (func(), error) { //nolint: cyclop +func (cn *ContractNegotiationInitial) Send(ctx context.Context) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "send_type", fmt.Sprintf("%T", cn)) if (cn.GetConsumerPID() == emptyUUID && cn.GetProviderPID() == emptyUUID) || (cn.GetConsumerPID() != emptyUUID && cn.GetProviderPID() != emptyUUID) { @@ -153,10 +143,6 @@ func (cn *ContractNegotiationInitial) Send(ctx context.Context) (func(), error) switch { case cn.GetConsumerPID() != emptyUUID: - if err := cn.a.PutConsumerContract(ctx, cn.GetContract()); err != nil { - logger.Error("failed to save contract", "err", err) - return func() {}, fmt.Errorf("could not save contract: %w", err) - } return sendContractRequest(ctx, cn.GetReconciler(), cn.GetContract()) case cn.GetProviderPID() != emptyUUID: targetID, err := shared.URNtoRawID(cn.GetOffer().Target) @@ -171,10 +157,6 @@ func (cn *ContractNegotiationInitial) Send(ctx context.Context) (func(), error) logger.Error("Dataset not found", "err", err) return nil, ErrNotFound } - if err := cn.a.PutProviderContract(ctx, cn.GetContract()); err != nil { - logger.Error("Could not send contract", "err", err) - return func() {}, fmt.Errorf("could not save contract: %w", err) - } return sendContractOffer(ctx, cn.GetReconciler(), cn.GetContract()) default: logger.Error("Could not deduce type of contract") @@ -184,7 +166,7 @@ func (cn *ContractNegotiationInitial) Send(ctx context.Context) (func(), error) // ContractNegotiationRequested represents the requested state. type ContractNegotiationRequested struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -196,16 +178,16 @@ func (cn *ContractNegotiationRequested) Recv( ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn)) logger.Debug("Receiving message") var consumerPID, providerPID, callbackAddress string - var targetState ContractState + var targetState contract.State switch t := message.(type) { case shared.ContractOfferMessage: consumerPID = t.ConsumerPID providerPID = t.ProviderPID callbackAddress = t.CallbackAddress - targetState = ContractStates.OFFERED + targetState = contract.States.OFFERED if ppid, err := uuid.Parse(providerPID); err == nil && cn.GetProviderPID() == emptyUUID { - cn.Contract.ProviderPID = ppid + cn.Negotiation.SetProviderPID(ppid) } ctx, logger = logging.InjectLabels(ctx, "recv_msg_type", fmt.Sprintf("%T", t), @@ -215,8 +197,8 @@ func (cn *ContractNegotiationRequested) Recv( consumerPID = t.ConsumerPID providerPID = t.ProviderPID callbackAddress = t.CallbackAddress - cn.Contract.Agreement = t.Agreement - targetState = ContractStates.AGREED + cn.Negotiation.SetAgreement(&t.Agreement) + targetState = contract.States.AGREED ctx, logger = logging.InjectLabels(ctx, "recv_msg_type", fmt.Sprintf("%T", t), ) @@ -233,16 +215,16 @@ func (cn *ContractNegotiationRequested) Recv( func (cn *ContractNegotiationRequested) Send(ctx context.Context) (func(), error) { ctx, _ = logging.InjectLabels(ctx, "send_type", fmt.Sprintf("%T", cn)) // Detect if this is a consumer initiated or provider initiated request. - if cn.Contract.initial { - cn.Contract.initial = false + if cn.Negotiation.Initial() { + cn.Negotiation.UnsetInitial() return sendContractOffer(ctx, cn.GetReconciler(), cn.GetContract()) } else { - return sendContractAgreement(ctx, cn.GetReconciler(), cn.GetContract(), cn.GetArchiver()) + return sendContractAgreement(ctx, cn.GetReconciler(), cn.GetContract()) } } type ContractNegotiationOffered struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -254,16 +236,16 @@ func (cn *ContractNegotiationOffered) Recv( ctx, logger := logging.InjectLabels(ctx, "recv_type", fmt.Sprintf("%T", cn)) logger.Debug("Receiving message") var consumerPID, providerPID, callbackAddress string - var targetState ContractState + var targetState contract.State switch t := message.(type) { case shared.ContractRequestMessage: consumerPID = t.ConsumerPID providerPID = t.ProviderPID callbackAddress = t.CallbackAddress - targetState = ContractStates.REQUESTED + targetState = contract.States.REQUESTED if ppid, err := uuid.Parse(consumerPID); err == nil && cn.GetConsumerPID() == emptyUUID { - cn.Contract.ConsumerPID = ppid + cn.Negotiation.SetConsumerPID(ppid) } ctx, logger = logging.InjectLabels(ctx, "recv_msg_type", fmt.Sprintf("%T", t), @@ -277,12 +259,12 @@ func (cn *ContractNegotiationOffered) Recv( consumerPID = t.ConsumerPID providerPID = t.ProviderPID callbackAddress = cn.GetCallback().String() - receivedStatus, err := ParseContractState(t.EventType) + receivedStatus, err := contract.ParseState(t.EventType) if err != nil { logger.Error("Event contained invalid status", "err", err) return ctx, nil, fmt.Errorf("event %s does not contain proper status: %w", t.EventType, err) } - if receivedStatus != ContractStates.ACCEPTED { + if receivedStatus != contract.States.ACCEPTED { logger.Error("Event contained invalid status", "err", err) return ctx, nil, fmt.Errorf("invalid status: %s", receivedStatus) } @@ -299,17 +281,17 @@ func (cn *ContractNegotiationOffered) Recv( func (cn *ContractNegotiationOffered) Send(ctx context.Context) (func(), error) { ctx, _ = logging.InjectLabels(ctx, "send_type", fmt.Sprintf("%T", cn)) // Detect if this is a consumer initiated or provider initiated request. - if cn.Contract.initial { - cn.Contract.initial = false + if cn.Negotiation.Initial() { + cn.Negotiation.UnsetInitial() return sendContractRequest(ctx, cn.GetReconciler(), cn.GetContract()) } else { return sendContractEvent( - ctx, cn.GetReconciler(), cn.GetContract(), cn.GetProviderPID(), ContractStates.ACCEPTED) + ctx, cn.GetReconciler(), cn.GetContract(), cn.GetProviderPID(), contract.States.ACCEPTED) } } type ContractNegotiationAccepted struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -325,8 +307,8 @@ func (cn *ContractNegotiationAccepted) Recv( "recv_msg_type", fmt.Sprintf("%T", t), ) logger.Debug("Received message") - cn.Agreement = t.Agreement - return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, t.CallbackAddress, ContractStates.AGREED) + cn.SetAgreement(&t.Agreement) + return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, t.CallbackAddress, contract.States.AGREED) case shared.ContractNegotiationTerminationMessage: return processTermination(ctx, t, cn) default: @@ -336,11 +318,11 @@ func (cn *ContractNegotiationAccepted) Recv( func (cn *ContractNegotiationAccepted) Send(ctx context.Context) (func(), error) { ctx, _ = logging.InjectLabels(ctx, "send_type", fmt.Sprintf("%T", cn)) - return sendContractAgreement(ctx, cn.GetReconciler(), cn.GetContract(), cn.GetArchiver()) + return sendContractAgreement(ctx, cn.GetReconciler(), cn.GetContract()) } type ContractNegotiationAgreed struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -354,7 +336,7 @@ func (cn *ContractNegotiationAgreed) Recv( ctx, _ = logging.InjectLabels(ctx, "recv_msg_type", fmt.Sprintf("%T", t), ) - return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.VERIFIED) + return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), contract.States.VERIFIED) case shared.ContractNegotiationTerminationMessage: return processTermination(ctx, t, cn) default: @@ -368,7 +350,7 @@ func (cn *ContractNegotiationAgreed) Send(ctx context.Context) (func(), error) { } type ContractNegotiationVerified struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -383,18 +365,18 @@ func (cn *ContractNegotiationVerified) Recv( "recv_msg_type", fmt.Sprintf("%T", t), "event_type", t.EventType, ) - receivedStatus, err := ParseContractState(t.EventType) + receivedStatus, err := contract.ParseState(t.EventType) if err != nil { logger.Error("event does not contain the proper status", "err", err) return ctx, nil, fmt.Errorf("event %s does not contain proper status: %w", t.EventType, err) } - if receivedStatus != ContractStates.FINALIZED { + if receivedStatus != contract.States.FINALIZED { logger.Error("invalid status") return ctx, nil, fmt.Errorf("invalid status: %s", receivedStatus) } logger.Debug("Received message") return verifyAndTransform( - ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.FINALIZED) + ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), contract.States.FINALIZED) case shared.ContractNegotiationTerminationMessage: return processTermination(ctx, t, cn) default: @@ -405,11 +387,11 @@ func (cn *ContractNegotiationVerified) Recv( func (cn *ContractNegotiationVerified) Send(ctx context.Context) (func(), error) { ctx, _ = logging.InjectLabels(ctx, "send_type", fmt.Sprintf("%T", cn)) return sendContractEvent( - ctx, cn.GetReconciler(), cn.GetContract(), cn.GetConsumerPID(), ContractStates.FINALIZED) + ctx, cn.GetReconciler(), cn.GetContract(), cn.GetConsumerPID(), contract.States.FINALIZED) } type ContractNegotiationFinalized struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -424,7 +406,7 @@ func (cn *ContractNegotiationFinalized) Send(ctx context.Context) (func(), error } type ContractNegotiationTerminated struct { - *Contract + *contract.Negotiation stateMachineDeps } @@ -439,65 +421,31 @@ func (cn *ContractNegotiationTerminated) Send(ctx context.Context) (func(), erro return func() {}, nil } -func NewContract( - ctx context.Context, - store Archiver, - provider providerv1.ProviderServiceClient, - reconciler *Reconciler, - providerPID, consumerPID uuid.UUID, - state ContractState, - offer odrl.Offer, - callback, self *url.URL, - role DataspaceRole, -) (context.Context, ContractNegotiationState, error) { - contract := &Contract{ - ProviderPID: providerPID, - ConsumerPID: consumerPID, - State: state, - Offer: offer, - Callback: callback, - Self: self, - Role: role, - } - var err error - if role == DataspaceConsumer { - err = store.PutConsumerContract(ctx, contract) - } else { - err = store.PutProviderContract(ctx, contract) - } - if err != nil { - return ctx, nil, err - } - ctx, cn := GetContractNegotiation(ctx, store, contract, provider, reconciler) - return ctx, cn, nil -} - func GetContractNegotiation( ctx context.Context, - store Archiver, - c *Contract, + c *contract.Negotiation, p providerv1.ProviderServiceClient, r *Reconciler, ) (context.Context, ContractNegotiationState) { var cns ContractNegotiationState - deps := stateMachineDeps{a: store, p: p, r: r} + deps := stateMachineDeps{p: p, r: r} switch c.GetState() { - case ContractStates.INITIAL: - cns = &ContractNegotiationInitial{Contract: c, stateMachineDeps: deps} - case ContractStates.REQUESTED: - cns = &ContractNegotiationRequested{Contract: c, stateMachineDeps: deps} - case ContractStates.OFFERED: - cns = &ContractNegotiationOffered{Contract: c, stateMachineDeps: deps} - case ContractStates.AGREED: - cns = &ContractNegotiationAgreed{Contract: c, stateMachineDeps: deps} - case ContractStates.ACCEPTED: - cns = &ContractNegotiationAccepted{Contract: c, stateMachineDeps: deps} - case ContractStates.VERIFIED: - cns = &ContractNegotiationVerified{Contract: c, stateMachineDeps: deps} - case ContractStates.FINALIZED: - cns = &ContractNegotiationFinalized{Contract: c, stateMachineDeps: deps} - case ContractStates.TERMINATED: - cns = &ContractNegotiationTerminated{Contract: c, stateMachineDeps: deps} + case contract.States.INITIAL: + cns = &ContractNegotiationInitial{Negotiation: c, stateMachineDeps: deps} + case contract.States.REQUESTED: + cns = &ContractNegotiationRequested{Negotiation: c, stateMachineDeps: deps} + case contract.States.OFFERED: + cns = &ContractNegotiationOffered{Negotiation: c, stateMachineDeps: deps} + case contract.States.AGREED: + cns = &ContractNegotiationAgreed{Negotiation: c, stateMachineDeps: deps} + case contract.States.ACCEPTED: + cns = &ContractNegotiationAccepted{Negotiation: c, stateMachineDeps: deps} + case contract.States.VERIFIED: + cns = &ContractNegotiationVerified{Negotiation: c, stateMachineDeps: deps} + case contract.States.FINALIZED: + cns = &ContractNegotiationFinalized{Negotiation: c, stateMachineDeps: deps} + case contract.States.TERMINATED: + cns = &ContractNegotiationTerminated{Negotiation: c, stateMachineDeps: deps} default: panic("Invalid contract state.") } @@ -505,7 +453,7 @@ func GetContractNegotiation( "contract_consumerPID", cns.GetConsumerPID().String(), "contract_providerPID", cns.GetProviderPID().String(), "contract_state", cns.GetState().String(), - "contract_role", cns.GetContract().Role, + "contract_role", cns.GetContract().GetRole(), ) logger.Debug("Found contract") return ctx, cns @@ -515,7 +463,7 @@ func verifyAndTransform( ctx context.Context, cn ContractNegotiationState, providerPID, consumerPID, callbackAddress string, - targetState ContractState, + targetState contract.State, ) (context.Context, ContractNegotiationState, error) { ctx, logger := logging.InjectLabels(ctx, "target_state", targetState) if cn.GetProviderPID().URN() != strings.ToLower(providerPID) { @@ -552,25 +500,7 @@ func verifyAndTransform( return ctx, nil, fmt.Errorf("could not set state: %w", err) } - if cn.GetContract().Role == DataspaceConsumer { - err = cn.GetArchiver().PutConsumerContract(ctx, cn.GetContract()) - } else { - err = cn.GetArchiver().PutProviderContract(ctx, cn.GetContract()) - } - if err != nil { - logger.Error("Could not set state", "err", err) - return ctx, nil, fmt.Errorf("failed to save contract: %w", err) - } - - if cn.GetContract().Role == DataspaceConsumer && targetState == ContractStates.FINALIZED { - err = cn.GetArchiver().PutAgreement(ctx, &cn.GetContract().Copy().Agreement) - if err != nil { - logger.Error("Could not set state", "err", err) - return ctx, nil, fmt.Errorf("failed to save agreement: %w", err) - } - } - - ctx, cns := GetContractNegotiation(ctx, cn.GetArchiver(), cn.GetContract(), cn.GetProvider(), cn.GetReconciler()) + ctx, cns := GetContractNegotiation(ctx, cn.GetContract(), cn.GetProvider(), cn.GetReconciler()) return ctx, cns, nil } @@ -583,5 +513,5 @@ func processTermination( logger = logger.With(fmt.Sprintf("reason_%s", reason.Language), reason.Value) } ctx = logging.Inject(ctx, logger) - return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), ContractStates.TERMINATED) + return verifyAndTransform(ctx, cn, t.ProviderPID, t.ConsumerPID, cn.GetCallback().String(), contract.States.TERMINATED) } diff --git a/dsp/statemachine/contractstates_enums.go b/dsp/statemachine/contractstates_enums.go deleted file mode 100644 index 243103a..0000000 --- a/dsp/statemachine/contractstates_enums.go +++ /dev/null @@ -1,194 +0,0 @@ -// Code generated by goenums. DO NOT EDIT. -// This file was generated by github.com/zarldev/goenums -// using the command: -// goenums contract_state.go - -package statemachine - -import ( - "bytes" - "database/sql/driver" - "fmt" - "strconv" -) - -type ContractState struct { - contractState -} - -type contractstatesContainer struct { - INITIAL ContractState - REQUESTED ContractState - OFFERED ContractState - AGREED ContractState - ACCEPTED ContractState - VERIFIED ContractState - FINALIZED ContractState - TERMINATED ContractState -} - -var ContractStates = contractstatesContainer{ - INITIAL: ContractState{ - contractState: initial, - }, - REQUESTED: ContractState{ - contractState: requested, - }, - OFFERED: ContractState{ - contractState: offered, - }, - AGREED: ContractState{ - contractState: agreed, - }, - ACCEPTED: ContractState{ - contractState: accepted, - }, - VERIFIED: ContractState{ - contractState: verified, - }, - FINALIZED: ContractState{ - contractState: finalized, - }, - TERMINATED: ContractState{ - contractState: terminated, - }, -} - -func (c contractstatesContainer) All() []ContractState { - return []ContractState{ - c.INITIAL, - c.REQUESTED, - c.OFFERED, - c.AGREED, - c.ACCEPTED, - c.VERIFIED, - c.FINALIZED, - c.TERMINATED, - } -} - -var invalidContractState = ContractState{} - -func ParseContractState(a any) (ContractState, error) { - res := invalidContractState - switch v := a.(type) { - case ContractState: - return v, nil - case []byte: - res = stringToContractState(string(v)) - case string: - res = stringToContractState(v) - case fmt.Stringer: - res = stringToContractState(v.String()) - case int: - res = intToContractState(v) - case int64: - res = intToContractState(int(v)) - case int32: - res = intToContractState(int(v)) - } - return res, nil -} - -func stringToContractState(s string) ContractState { - switch s { - case "INITIAL": - return ContractStates.INITIAL - case "dspace:REQUESTED": - return ContractStates.REQUESTED - case "dspace:OFFERED": - return ContractStates.OFFERED - case "dspace:AGREED": - return ContractStates.AGREED - case "dspace:ACCEPTED": - return ContractStates.ACCEPTED - case "dspace:VERIFIED": - return ContractStates.VERIFIED - case "dspace:FINALIZED": - return ContractStates.FINALIZED - case "dspace:TERMINATED": - return ContractStates.TERMINATED - } - return invalidContractState -} - -func intToContractState(i int) ContractState { - if i < 0 || i >= len(ContractStates.All()) { - return invalidContractState - } - return ContractStates.All()[i] -} - -func ExhaustiveContractStates(f func(ContractState)) { - for _, p := range ContractStates.All() { - f(p) - } -} - -var validContractStates = map[ContractState]bool{ - ContractStates.INITIAL: true, - ContractStates.REQUESTED: true, - ContractStates.OFFERED: true, - ContractStates.AGREED: true, - ContractStates.ACCEPTED: true, - ContractStates.VERIFIED: true, - ContractStates.FINALIZED: true, - ContractStates.TERMINATED: true, -} - -func (p ContractState) IsValid() bool { - return validContractStates[p] -} - -func (p ContractState) MarshalJSON() ([]byte, error) { - return []byte(`"` + p.String() + `"`), nil -} - -func (p *ContractState) UnmarshalJSON(b []byte) error { - b = bytes.Trim(bytes.Trim(b, `"`), ` `) - newp, err := ParseContractState(b) - if err != nil { - return err - } - *p = newp - return nil -} - -func (p *ContractState) Scan(value any) error { - newp, err := ParseContractState(value) - if err != nil { - return err - } - *p = newp - return nil -} - -func (p ContractState) Value() (driver.Value, error) { - return p.String(), nil -} - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the goenums command to generate them again. - // Does not identify newly added constant values unless order changes - var x [1]struct{} - _ = x[initial-0] - _ = x[requested-1] - _ = x[offered-2] - _ = x[agreed-3] - _ = x[accepted-4] - _ = x[verified-5] - _ = x[finalized-6] - _ = x[terminated-7] -} - -const _contractstates_name = "INITIALdspace:REQUESTEDdspace:OFFEREDdspace:AGREEDdspace:ACCEPTEDdspace:VERIFIEDdspace:FINALIZEDdspace:TERMINATED" - -var _contractstates_index = [...]uint16{0, 7, 23, 37, 50, 65, 80, 96, 113} - -func (i contractState) String() string { - if i < 0 || i >= contractState(len(_contractstates_index)-1) { - return "contractstates(" + (strconv.FormatInt(int64(i), 10) + ")") - } - return _contractstates_name[_contractstates_index[i]:_contractstates_index[i+1]] -} diff --git a/dsp/statemachine/contractstates_gob.go b/dsp/statemachine/contractstates_gob.go deleted file mode 100644 index d000754..0000000 --- a/dsp/statemachine/contractstates_gob.go +++ /dev/null @@ -1,15 +0,0 @@ -package statemachine - -func (p ContractState) GobEncode() ([]byte, error) { - return []byte(p.String()), nil -} - -func (p *ContractState) GobDecode(b []byte) error { - newp, err := ParseContractState(b) - if err != nil { - return err - } - - *p = newp - return nil -} diff --git a/dsp/statemachine/reconciler.go b/dsp/statemachine/reconciler.go index 2fe7add..452361a 100644 --- a/dsp/statemachine/reconciler.go +++ b/dsp/statemachine/reconciler.go @@ -24,7 +24,11 @@ import ( "time" "github.com/gammazero/deque" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/contract" + "github.com/go-dataspace/run-dsp/dsp/persistence" "github.com/go-dataspace/run-dsp/dsp/shared" + "github.com/go-dataspace/run-dsp/dsp/transfer" "github.com/go-dataspace/run-dsp/logging" "github.com/google/uuid" ) @@ -66,7 +70,7 @@ type reconciliationOperation struct { type ReconciliationEntry struct { EntityID uuid.UUID Type ReconciliationType - Role DataspaceRole + Role constants.DataspaceRole TargetState string Method string URL *url.URL @@ -84,7 +88,7 @@ type Reconciler struct { ctx context.Context c chan reconciliationOperation r shared.Requester - a Archiver + s persistence.StorageProvider q *deque.Deque[reconciliationOperation] // Waitgroup to keep track of management/worker processes, not called from the command yet, @@ -93,7 +97,7 @@ type Reconciler struct { sync.Mutex } -func NewReconciler(ctx context.Context, r shared.Requester, a Archiver) *Reconciler { +func NewReconciler(ctx context.Context, r shared.Requester, s persistence.StorageProvider) *Reconciler { q := &deque.Deque[reconciliationOperation]{} q.Grow(initialQueueSize) @@ -101,7 +105,7 @@ func NewReconciler(ctx context.Context, r shared.Requester, a Archiver) *Reconci ctx: ctx, c: make(chan reconciliationOperation), r: r, - a: a, + s: s, q: q, } } @@ -270,20 +274,14 @@ func (r *Reconciler) updateState( } } -//nolint:dupl func (c *Reconciler) setTransferState( - ctx context.Context, state string, role DataspaceRole, id uuid.UUID, + ctx context.Context, state string, role constants.DataspaceRole, id uuid.UUID, ) error { - ts, err := ParseTransferRequestState(state) + ts, err := transfer.ParseState(state) if err != nil { return fmt.Errorf("%w: Invalid state: %w", ErrFatal, err) } - var tr *TransferRequest - if role == DataspaceConsumer { - tr, err = c.a.GetConsumerTransfer(ctx, id) - } else { - tr, err = c.a.GetProviderTransfer(ctx, id) - } + tr, err := c.s.GetTransferRW(ctx, id, role) if err != nil { return fmt.Errorf("Can't find transfer request: %w", err) } @@ -291,31 +289,22 @@ func (c *Reconciler) setTransferState( if err != nil { return fmt.Errorf("Can't change state: %w", err) } - if role == DataspaceConsumer { - err = c.a.PutConsumerTransfer(ctx, tr) - } else { - err = c.a.PutProviderTransfer(ctx, tr) - } + err = c.s.PutTransfer(ctx, tr) if err != nil { return fmt.Errorf("Can't save transfer request: %w", err) } return nil } -//nolint:dupl func (c *Reconciler) setContractState( - ctx context.Context, state string, role DataspaceRole, id uuid.UUID, + ctx context.Context, state string, role constants.DataspaceRole, id uuid.UUID, ) error { - cs, err := ParseContractState(state) + cs, err := contract.ParseState(state) if err != nil { return fmt.Errorf("%w: Invalid state: %w", ErrFatal, err) } - var con *Contract - if role == DataspaceConsumer { - con, err = c.a.GetConsumerContract(ctx, id) - } else { - con, err = c.a.GetProviderContract(ctx, id) - } + var con *contract.Negotiation + con, err = c.s.GetContractRW(ctx, id, role) if err != nil { return fmt.Errorf("Can't find contract: %w", err) } @@ -323,11 +312,7 @@ func (c *Reconciler) setContractState( if err != nil { return fmt.Errorf("Can't change state: %w", err) } - if role == DataspaceConsumer { - err = c.a.PutConsumerContract(ctx, con) - } else { - err = c.a.PutProviderContract(ctx, con) - } + err = c.s.PutContract(ctx, con) if err != nil { return fmt.Errorf("Can't save contract: %w", err) } diff --git a/dsp/statemachine/transfer_messages.go b/dsp/statemachine/transfer_messages.go index 05a1562..0b322e4 100644 --- a/dsp/statemachine/transfer_messages.go +++ b/dsp/statemachine/transfer_messages.go @@ -21,7 +21,9 @@ import ( "path" "strings" + "github.com/go-dataspace/run-dsp/dsp/constants" "github.com/go-dataspace/run-dsp/dsp/shared" + "github.com/go-dataspace/run-dsp/dsp/transfer" "github.com/go-dataspace/run-dsp/logging" providerv1 "github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" "github.com/google/uuid" @@ -29,14 +31,14 @@ import ( func makeTransferRequestFunction( ctx context.Context, - t *TransferRequest, + t *transfer.Request, cu *url.URL, reqBody []byte, - destinationState TransferRequestState, + destinationState transfer.State, reconciler *Reconciler, ) func() { var id uuid.UUID - if t.GetRole() == DataspaceConsumer { + if t.GetRole() == constants.DataspaceConsumer { id = t.GetConsumerPID() } else { id = t.GetProviderPID() @@ -56,7 +58,7 @@ func makeTransferRequestFunction( func sendTransferRequest(ctx context.Context, tr *TransferRequestNegotiationInitial) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendTransferRequest") transferRequest := shared.TransferRequestMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:TransferRequestMessage", AgreementID: tr.GetAgreementID().URN(), Format: tr.GetFormat(), @@ -78,7 +80,7 @@ func sendTransferRequest(ctx context.Context, tr *TransferRequestNegotiationInit tr.GetTransferRequest(), cu, reqBody, - TransferRequestStates.TRANSFERREQUESTED, + transfer.States.REQUESTED, tr.GetReconciler(), ), nil } @@ -86,7 +88,7 @@ func sendTransferRequest(ctx context.Context, tr *TransferRequestNegotiationInit func sendTransferStart(ctx context.Context, tr *TransferRequestNegotiationRequested) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendTransferStarted") startRequest := shared.TransferStartMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:TransferStartMessage", ProviderPID: tr.GetProviderPID().URN(), ConsumerPID: tr.GetConsumerPID().URN(), @@ -100,7 +102,7 @@ func sendTransferStart(ctx context.Context, tr *TransferRequestNegotiationReques } pid := tr.GetConsumerPID().String() - if tr.GetRole() == DataspaceConsumer { + if tr.GetRole() == constants.DataspaceConsumer { pid = tr.GetProviderPID().String() } cu := cloneURL(tr.GetCallback()) @@ -111,7 +113,7 @@ func sendTransferStart(ctx context.Context, tr *TransferRequestNegotiationReques tr.GetTransferRequest(), cu, reqBody, - TransferRequestStates.STARTED, + transfer.States.STARTED, tr.GetReconciler(), ), nil } @@ -119,7 +121,7 @@ func sendTransferStart(ctx context.Context, tr *TransferRequestNegotiationReques func sendTransferCompletion(ctx context.Context, tr *TransferRequestNegotiationStarted) (func(), error) { ctx, logger := logging.InjectLabels(ctx, "operation", "sendTransferCompletion") startRequest := shared.TransferCompletionMessage{ - Context: dspaceContext, + Context: shared.GetDSPContext(), Type: "dspace:TransferCompletionMessage", ProviderPID: tr.GetProviderPID().URN(), ConsumerPID: tr.GetConsumerPID().URN(), @@ -132,7 +134,7 @@ func sendTransferCompletion(ctx context.Context, tr *TransferRequestNegotiationS } pid := tr.GetConsumerPID().String() - if tr.GetRole() == DataspaceConsumer { + if tr.GetRole() == constants.DataspaceConsumer { pid = tr.GetProviderPID().String() } @@ -144,7 +146,7 @@ func sendTransferCompletion(ctx context.Context, tr *TransferRequestNegotiationS tr.GetTransferRequest(), cu, reqBody, - TransferRequestStates.COMPLETED, + transfer.States.COMPLETED, tr.GetReconciler(), ), nil } diff --git a/dsp/statemachine/transfer_request.go b/dsp/statemachine/transfer_request.go deleted file mode 100644 index 4df35eb..0000000 --- a/dsp/statemachine/transfer_request.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2024 go-dataspace -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package statemachine - -import ( - "fmt" - "net/url" - "slices" - "strconv" - - "github.com/go-dataspace/run-dsp/dsp/shared" - providerv1 "github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" - "github.com/google/uuid" -) - -var validTransferTransitions = map[TransferRequestState][]TransferRequestState{ - TransferRequestStates.TRANSFERINITIAL: { - TransferRequestStates.TRANSFERREQUESTED, - TransferRequestStates.TRANSFERTERMINATED, - }, - TransferRequestStates.TRANSFERREQUESTED: { - TransferRequestStates.STARTED, - TransferRequestStates.TRANSFERTERMINATED, - }, - TransferRequestStates.STARTED: { - TransferRequestStates.SUSPENDED, - TransferRequestStates.COMPLETED, - TransferRequestStates.TRANSFERTERMINATED, - }, - TransferRequestStates.SUSPENDED: { - TransferRequestStates.STARTED, - TransferRequestStates.TRANSFERTERMINATED, - }, - TransferRequestStates.COMPLETED: {}, -} - -type TransferDirection uint8 - -const ( - DirectionUnknown TransferDirection = iota - DirectionPull - DirectionPush -) - -// TransferRequest represents a transfer request and its state. -type TransferRequest struct { - State TransferRequestState - ProviderPID uuid.UUID - ConsumerPID uuid.UUID - AgreementID uuid.UUID - Target string - Format string - Callback *url.URL - Self *url.URL - Role DataspaceRole - PublishInfo *providerv1.PublishInfo - TransferDirection TransferDirection - - ro bool -} - -func (tr *TransferRequest) GetProviderPID() uuid.UUID { return tr.ProviderPID } -func (tr *TransferRequest) GetConsumerPID() uuid.UUID { return tr.ConsumerPID } -func (tr *TransferRequest) GetAgreementID() uuid.UUID { return tr.AgreementID } -func (tr *TransferRequest) GetTarget() string { return tr.Target } -func (tr *TransferRequest) GetFormat() string { return tr.Format } -func (tr *TransferRequest) GetCallback() *url.URL { return tr.Callback } -func (tr *TransferRequest) GetSelf() *url.URL { return tr.Self } -func (tr *TransferRequest) GetState() TransferRequestState { return tr.State } -func (tr *TransferRequest) GetRole() DataspaceRole { return tr.Role } -func (tr *TransferRequest) GetTransferRequest() *TransferRequest { return tr } -func (tr *TransferRequest) GetPublishInfo() *providerv1.PublishInfo { return tr.PublishInfo } -func (tr *TransferRequest) GetTransferDirection() TransferDirection { return tr.TransferDirection } - -func (tr *TransferRequest) SetReadOnly() { tr.ro = true } -func (tr *TransferRequest) ReadOnly() bool { return tr.ro } - -func (tr *TransferRequest) StorageKey() []byte { - id := tr.ConsumerPID - if tr.Role == DataspaceProvider { - id = tr.ProviderPID - } - return MkTransferKey(id, tr.Role) -} - -func (tr *TransferRequest) SetState(state TransferRequestState) error { - if !slices.Contains(validTransferTransitions[tr.State], state) { - return fmt.Errorf("can't transition from %s to %s", tr.State, state) - } - tr.State = state - return nil -} - -func (tr *TransferRequest) GetTransferProcess() shared.TransferProcess { - return shared.TransferProcess{ - Context: dspaceContext, - Type: "dspace:TransferProcess", - ProviderPID: tr.ProviderPID.URN(), - ConsumerPID: tr.ConsumerPID.URN(), - State: tr.State.String(), - } -} - -func (tr *TransferRequest) SetProviderPID(id uuid.UUID) { tr.ProviderPID = id } - -func MkTransferKey(id uuid.UUID, role DataspaceRole) []byte { - return []byte("transfer-" + id.String() + "-" + strconv.Itoa(int(role))) -} diff --git a/dsp/statemachine/transfer_request_transitions.go b/dsp/statemachine/transfer_request_transitions.go index 8530fbe..bd77f1d 100644 --- a/dsp/statemachine/transfer_request_transitions.go +++ b/dsp/statemachine/transfer_request_transitions.go @@ -21,7 +21,9 @@ import ( "net/url" "strings" + "github.com/go-dataspace/run-dsp/dsp/constants" "github.com/go-dataspace/run-dsp/dsp/shared" + "github.com/go-dataspace/run-dsp/dsp/transfer" providerv1 "github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" "github.com/google/uuid" ) @@ -36,12 +38,12 @@ type TransferRequester interface { GetFormat() string GetCallback() *url.URL GetSelf() *url.URL - GetState() TransferRequestState - GetRole() DataspaceRole - SetState(state TransferRequestState) error - GetTransferRequest() *TransferRequest + GetState() transfer.State + GetRole() constants.DataspaceRole + SetState(state transfer.State) error + GetTransferRequest() *transfer.Request GetPublishInfo() *providerv1.PublishInfo - GetTransferDirection() TransferDirection + GetTransferDirection() transfer.Direction GetTransferProcess() shared.TransferProcess } @@ -49,13 +51,12 @@ type TransferRequestNegotiationState interface { TransferRequester Recv(ctx context.Context, message any) (TransferRequestNegotiationState, error) Send(ctx context.Context) (func(), error) - GetArchiver() Archiver GetProvider() providerv1.ProviderServiceClient GetReconciler() *Reconciler } type TransferRequestNegotiationInitial struct { - *TransferRequest + *transfer.Request stateMachineDeps } @@ -70,9 +71,9 @@ func (tr *TransferRequestNegotiationInitial) Recv( if err != nil { return nil, fmt.Errorf("could not find target: %w", err) } - tr.ProviderPID = uuid.New() + tr.SetProviderPID(uuid.New()) return verifyAndTransformTransfer( - ctx, tr, tr.ProviderPID.URN(), t.ConsumerPID, TransferRequestStates.TRANSFERREQUESTED) + tr, tr.GetProviderPID().URN(), t.ConsumerPID, transfer.States.REQUESTED) default: return nil, fmt.Errorf("invalid message type") } @@ -83,7 +84,7 @@ func (tr *TransferRequestNegotiationInitial) Send(ctx context.Context) (func(), } type TransferRequestNegotiationRequested struct { - *TransferRequest + *transfer.Request stateMachineDeps } @@ -97,18 +98,19 @@ func (tr *TransferRequestNegotiationRequested) Recv( if err != nil { return nil, fmt.Errorf("invalid UUID for provider PID: %w", err) } - tr.ProviderPID = u + tr.SetProviderPID(u) } - if tr.PublishInfo == nil { + if tr.GetPublishInfo() == nil { var err error - tr.PublishInfo, err = dataAddressToPublishInfo(t.DataAddress) + pi, err := dataAddressToPublishInfo(t.DataAddress) if err != nil { return nil, fmt.Errorf("invalid dataAddress supplied: %w", err) } + tr.SetPublishInfo(pi) } - return verifyAndTransformTransfer(ctx, tr, t.ProviderPID, t.ConsumerPID, TransferRequestStates.STARTED) + return verifyAndTransformTransfer(tr, t.ProviderPID, t.ConsumerPID, transfer.States.STARTED) case shared.TransferTerminationMessage: - return verifyAndTransformTransfer(ctx, tr, t.ProviderPID, t.ConsumerPID, TransferRequestStates.TRANSFERTERMINATED) + return verifyAndTransformTransfer(tr, t.ProviderPID, t.ConsumerPID, transfer.States.TERMINATED) default: return nil, fmt.Errorf("invalid message type") } @@ -116,7 +118,7 @@ func (tr *TransferRequestNegotiationRequested) Recv( func (tr *TransferRequestNegotiationRequested) Send(ctx context.Context) (func(), error) { switch tr.GetTransferDirection() { - case DirectionPull: + case transfer.DirectionPull: resp, err := tr.GetProvider().PublishDataset(ctx, &providerv1.PublishDatasetRequest{ DatasetId: tr.GetTarget(), PublishId: tr.GetProviderPID().String(), @@ -124,11 +126,11 @@ func (tr *TransferRequestNegotiationRequested) Send(ctx context.Context) (func() if err != nil { return func() {}, err } - tr.PublishInfo = resp.PublishInfo - case DirectionPush: + tr.SetPublishInfo(resp.PublishInfo) + case transfer.DirectionPush: // TODO: Signal provider to start uploading dataset here. return func() {}, fmt.Errorf("push flow: %w", ErrNotImplemented) - case DirectionUnknown: + case transfer.DirectionUnknown: return func() {}, fmt.Errorf("unknown transfer direction") default: panic("unexpected statemachine.TransferDirection") @@ -138,7 +140,7 @@ func (tr *TransferRequestNegotiationRequested) Send(ctx context.Context) (func() } type TransferRequestNegotiationStarted struct { - *TransferRequest + *transfer.Request stateMachineDeps } @@ -151,9 +153,9 @@ func (tr *TransferRequestNegotiationStarted) Recv( if err != nil { return nil, err } - return verifyAndTransformTransfer(ctx, tr, t.ProviderPID, t.ConsumerPID, TransferRequestStates.COMPLETED) + return verifyAndTransformTransfer(tr, t.ProviderPID, t.ConsumerPID, transfer.States.COMPLETED) case shared.TransferTerminationMessage: - return verifyAndTransformTransfer(ctx, tr, t.ProviderPID, t.ConsumerPID, TransferRequestStates.TRANSFERTERMINATED) + return verifyAndTransformTransfer(tr, t.ProviderPID, t.ConsumerPID, transfer.States.TERMINATED) default: return nil, fmt.Errorf("invalid message type") } @@ -169,8 +171,8 @@ func (tr *TransferRequestNegotiationStarted) Send(ctx context.Context) (func(), func unpublishTransfer(ctx context.Context, tr TransferRequestNegotiationState) error { switch tr.GetTransferDirection() { - case DirectionPull: - if tr.GetRole() == DataspaceProvider { + case transfer.DirectionPull: + if tr.GetRole() == constants.DataspaceProvider { _, err := tr.GetProvider().UnpublishDataset(ctx, &providerv1.UnpublishDatasetRequest{ PublishId: tr.GetProviderPID().String(), }) @@ -178,10 +180,9 @@ func unpublishTransfer(ctx context.Context, tr TransferRequestNegotiationState) return err } } - case DirectionPush: - + case transfer.DirectionPush: return fmt.Errorf("push flow: %w", ErrNotImplemented) - case DirectionUnknown: + case transfer.DirectionUnknown: return fmt.Errorf("unknown transfer direction") default: panic("unexpected statemachine.TransferDirection") @@ -190,12 +191,12 @@ func unpublishTransfer(ctx context.Context, tr TransferRequestNegotiationState) } type TransferRequestNegotiationSuspended struct { - *TransferRequest + *transfer.Request stateMachineDeps } type TransferRequestNegotiationCompleted struct { - *TransferRequest + *transfer.Request stateMachineDeps } @@ -210,7 +211,7 @@ func (tr *TransferRequestNegotiationCompleted) Send(ctx context.Context) (func() } type TransferRequestNegotiationTerminated struct { - *TransferRequest + *transfer.Request stateMachineDeps } @@ -224,67 +225,21 @@ func (tr *TransferRequestNegotiationTerminated) Send(ctx context.Context) (func( return func() {}, nil } -func NewTransferRequest( - ctx context.Context, - store Archiver, - provider providerv1.ProviderServiceClient, - reconciler *Reconciler, - consumerPID, agreementID uuid.UUID, - format string, - callback, self *url.URL, - role DataspaceRole, - state TransferRequestState, - publishInfo *providerv1.PublishInfo, -) (TransferRequestNegotiationState, error) { - agreement, err := store.GetAgreement(ctx, agreementID) - if err != nil { - return nil, fmt.Errorf("no agreement found") - } - targetID, err := shared.URNtoRawID(agreement.Target) - if err != nil { - return nil, fmt.Errorf("couldn't parse target URN: %w", err) - } - traReq := &TransferRequest{ - State: state, - ConsumerPID: consumerPID, - AgreementID: agreementID, - Target: targetID, - Format: format, - Callback: callback, - Self: self, - Role: role, - PublishInfo: publishInfo, - TransferDirection: DirectionPush, - } - if publishInfo == nil { - traReq.TransferDirection = DirectionPull - } - if role == DataspaceConsumer { - err = store.PutConsumerTransfer(ctx, traReq) - } else { - err = store.PutProviderTransfer(ctx, traReq) - } - if err != nil { - return nil, err - } - return GetTransferRequestNegotiation(store, traReq, provider, reconciler), nil -} - func GetTransferRequestNegotiation( - a Archiver, tr *TransferRequest, p providerv1.ProviderServiceClient, r *Reconciler, + tr *transfer.Request, p providerv1.ProviderServiceClient, r *Reconciler, ) TransferRequestNegotiationState { - deps := stateMachineDeps{a: a, p: p, r: r} + deps := stateMachineDeps{p: p, r: r} switch tr.GetState() { - case TransferRequestStates.TRANSFERINITIAL: - return &TransferRequestNegotiationInitial{TransferRequest: tr, stateMachineDeps: deps} - case TransferRequestStates.TRANSFERREQUESTED: - return &TransferRequestNegotiationRequested{TransferRequest: tr, stateMachineDeps: deps} - case TransferRequestStates.STARTED: - return &TransferRequestNegotiationStarted{TransferRequest: tr, stateMachineDeps: deps} - case TransferRequestStates.COMPLETED: - return &TransferRequestNegotiationCompleted{TransferRequest: tr, stateMachineDeps: deps} - case TransferRequestStates.TRANSFERTERMINATED: - return &TransferRequestNegotiationTerminated{TransferRequest: tr, stateMachineDeps: deps} + case transfer.States.INITIAL: + return &TransferRequestNegotiationInitial{Request: tr, stateMachineDeps: deps} + case transfer.States.REQUESTED: + return &TransferRequestNegotiationRequested{Request: tr, stateMachineDeps: deps} + case transfer.States.STARTED: + return &TransferRequestNegotiationStarted{Request: tr, stateMachineDeps: deps} + case transfer.States.COMPLETED: + return &TransferRequestNegotiationCompleted{Request: tr, stateMachineDeps: deps} + case transfer.States.TERMINATED: + return &TransferRequestNegotiationTerminated{Request: tr, stateMachineDeps: deps} default: panic(fmt.Sprintf("No transition found for state %s", tr.GetState())) } @@ -330,10 +285,9 @@ func makeEndpointPropertyMap(p []shared.EndpointProperty) (map[string]string, er } func verifyAndTransformTransfer( - ctx context.Context, tr TransferRequestNegotiationState, providerPID, consumerPID string, - targetState TransferRequestState, + targetState transfer.State, ) (TransferRequestNegotiationState, error) { if tr.GetProviderPID().URN() != strings.ToLower(providerPID) { return nil, fmt.Errorf( @@ -352,15 +306,6 @@ func verifyAndTransformTransfer( if err := tr.SetState(targetState); err != nil { return nil, fmt.Errorf("could not set state: %w", err) } - var err error - if tr.GetRole() == DataspaceConsumer { - err = tr.GetArchiver().PutConsumerTransfer(ctx, tr.GetTransferRequest()) - } else { - err = tr.GetArchiver().PutProviderTransfer(ctx, tr.GetTransferRequest()) - } - if err != nil { - return nil, fmt.Errorf("failed to save contract: %w", err) - } return GetTransferRequestNegotiation( - tr.GetArchiver(), tr.GetTransferRequest(), tr.GetProvider(), tr.GetReconciler()), nil + tr.GetTransferRequest(), tr.GetProvider(), tr.GetReconciler()), nil } diff --git a/dsp/statemachine/transfer_statemachine_test.go b/dsp/statemachine/transfer_statemachine_test.go index bcc7c96..b63f8cc 100644 --- a/dsp/statemachine/transfer_statemachine_test.go +++ b/dsp/statemachine/transfer_statemachine_test.go @@ -19,8 +19,11 @@ import ( "testing" "time" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/persistence/badger" "github.com/go-dataspace/run-dsp/dsp/shared" "github.com/go-dataspace/run-dsp/dsp/statemachine" + "github.com/go-dataspace/run-dsp/dsp/transfer" "github.com/go-dataspace/run-dsp/logging" mockprovider "github.com/go-dataspace/run-dsp/mocks/github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" "github.com/go-dataspace/run-dsp/odrl" @@ -30,7 +33,6 @@ import ( var agreementID = uuid.MustParse("e1c68180-de68-428d-9853-7d4dd3c66904") -//nolint:funlen func TestTransferTermination(t *testing.T) { t.Parallel() @@ -44,36 +46,34 @@ func TestTransferTermination(t *testing.T) { ctx, done := context.WithCancel(ctx) defer done() - store := statemachine.NewMemoryArchiver() + store, err := badger.New(ctx, true, "") + assert.Nil(t, err) requester := &MockRequester{} mockProvider := mockprovider.NewMockProviderServiceClient(t) + err = store.PutAgreement(ctx, &agreement) + assert.Nil(t, err) reconciler := statemachine.NewReconciler(ctx, requester, store) reconciler.Run() - for _, role := range []statemachine.DataspaceRole{ - statemachine.DataspaceConsumer, - statemachine.DataspaceProvider, + for _, role := range []constants.DataspaceRole{ + constants.DataspaceConsumer, + constants.DataspaceProvider, } { - for _, state := range []statemachine.TransferRequestState{ - statemachine.TransferRequestStates.TRANSFERREQUESTED, - statemachine.TransferRequestStates.STARTED, + for _, state := range []transfer.State{ + transfer.States.REQUESTED, + transfer.States.STARTED, } { - err := store.PutAgreement(ctx, &agreement) - assert.Nil(t, err) - - pState, err := statemachine.NewTransferRequest( - ctx, - store, mockProvider, reconciler, - consumerPID, agreementID, + transReq := transfer.New( + consumerPID, &agreement, "HTTP_PULL", providerCallback, consumerCallback, role, state, nil, ) - assert.Nil(t, err) + pState := statemachine.GetTransferRequestNegotiation(transReq, mockProvider, reconciler) pState.GetTransferRequest().SetProviderPID(providerPID) transferMsg := shared.TransferTerminationMessage{ @@ -90,15 +90,7 @@ func TestTransferTermination(t *testing.T) { assert.Nil(t, err) _, err = next.Send(ctx) assert.Nil(t, err) - var transfer *statemachine.TransferRequest - switch role { - case statemachine.DataspaceProvider: - transfer, err = store.GetProviderTransfer(ctx, providerPID) - case statemachine.DataspaceConsumer: - transfer, err = store.GetConsumerTransfer(ctx, consumerPID) - } - assert.Nil(t, err) - assert.Equal(t, statemachine.TransferRequestStates.TRANSFERTERMINATED, transfer.GetState()) + assert.Equal(t, transfer.States.TERMINATED, next.GetTransferRequest().GetState()) } } } diff --git a/dsp/statemachine/transferrequeststates_enums.go b/dsp/statemachine/transferrequeststates_enums.go deleted file mode 100644 index 44c5c00..0000000 --- a/dsp/statemachine/transferrequeststates_enums.go +++ /dev/null @@ -1,176 +0,0 @@ -// Code generated by goenums. DO NOT EDIT. -// This file was generated by github.com/zarldev/goenums -// using the command: -// goenums transfer_request_state.go - -package statemachine - -import ( - "bytes" - "database/sql/driver" - "fmt" - "strconv" -) - -type TransferRequestState struct { - transferRequestState -} - -type transferrequeststatesContainer struct { - TRANSFERINITIAL TransferRequestState - TRANSFERREQUESTED TransferRequestState - STARTED TransferRequestState - SUSPENDED TransferRequestState - COMPLETED TransferRequestState - TRANSFERTERMINATED TransferRequestState -} - -var TransferRequestStates = transferrequeststatesContainer{ - TRANSFERINITIAL: TransferRequestState{ - transferRequestState: transferInitial, - }, - TRANSFERREQUESTED: TransferRequestState{ - transferRequestState: transferRequested, - }, - STARTED: TransferRequestState{ - transferRequestState: started, - }, - SUSPENDED: TransferRequestState{ - transferRequestState: suspended, - }, - COMPLETED: TransferRequestState{ - transferRequestState: completed, - }, - TRANSFERTERMINATED: TransferRequestState{ - transferRequestState: transferTerminated, - }, -} - -func (c transferrequeststatesContainer) All() []TransferRequestState { - return []TransferRequestState{ - c.TRANSFERINITIAL, - c.TRANSFERREQUESTED, - c.STARTED, - c.SUSPENDED, - c.COMPLETED, - c.TRANSFERTERMINATED, - } -} - -var invalidTransferRequestState = TransferRequestState{} - -func ParseTransferRequestState(a any) (TransferRequestState, error) { - res := invalidTransferRequestState - switch v := a.(type) { - case TransferRequestState: - return v, nil - case []byte: - res = stringToTransferRequestState(string(v)) - case string: - res = stringToTransferRequestState(v) - case fmt.Stringer: - res = stringToTransferRequestState(v.String()) - case int: - res = intToTransferRequestState(v) - case int64: - res = intToTransferRequestState(int(v)) - case int32: - res = intToTransferRequestState(int(v)) - } - return res, nil -} - -func stringToTransferRequestState(s string) TransferRequestState { - switch s { - case "INITIAL": - return TransferRequestStates.TRANSFERINITIAL - case "dspace:REQUESTED": - return TransferRequestStates.TRANSFERREQUESTED - case "dspace:STARTED": - return TransferRequestStates.STARTED - case "dspace:SUSPENDED": - return TransferRequestStates.SUSPENDED - case "dspace:COMPLETED": - return TransferRequestStates.COMPLETED - case "dspace:TERMINATED": - return TransferRequestStates.TRANSFERTERMINATED - } - return invalidTransferRequestState -} - -func intToTransferRequestState(i int) TransferRequestState { - if i < 0 || i >= len(TransferRequestStates.All()) { - return invalidTransferRequestState - } - return TransferRequestStates.All()[i] -} - -func ExhaustiveTransferRequestStates(f func(TransferRequestState)) { - for _, p := range TransferRequestStates.All() { - f(p) - } -} - -var validTransferRequestStates = map[TransferRequestState]bool{ - TransferRequestStates.TRANSFERINITIAL: true, - TransferRequestStates.TRANSFERREQUESTED: true, - TransferRequestStates.STARTED: true, - TransferRequestStates.SUSPENDED: true, - TransferRequestStates.COMPLETED: true, - TransferRequestStates.TRANSFERTERMINATED: true, -} - -func (p TransferRequestState) IsValid() bool { - return validTransferRequestStates[p] -} - -func (p TransferRequestState) MarshalJSON() ([]byte, error) { - return []byte(`"` + p.String() + `"`), nil -} - -func (p *TransferRequestState) UnmarshalJSON(b []byte) error { - b = bytes.Trim(bytes.Trim(b, `"`), ` `) - newp, err := ParseTransferRequestState(b) - if err != nil { - return err - } - *p = newp - return nil -} - -func (p *TransferRequestState) Scan(value any) error { - newp, err := ParseTransferRequestState(value) - if err != nil { - return err - } - *p = newp - return nil -} - -func (p TransferRequestState) Value() (driver.Value, error) { - return p.String(), nil -} - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the goenums command to generate them again. - // Does not identify newly added constant values unless order changes - var x [1]struct{} - _ = x[transferInitial-0] - _ = x[transferRequested-1] - _ = x[started-2] - _ = x[suspended-3] - _ = x[completed-4] - _ = x[transferTerminated-5] -} - -const _transferrequeststates_name = "INITIALdspace:REQUESTEDdspace:STARTEDdspace:SUSPENDEDdspace:COMPLETEDdspace:TERMINATED" - -var _transferrequeststates_index = [...]uint16{0, 7, 23, 37, 53, 69, 86} - -func (i transferRequestState) String() string { - if i < 0 || i >= transferRequestState(len(_transferrequeststates_index)-1) { - return "transferrequeststates(" + (strconv.FormatInt(int64(i), 10) + ")") - } - return _transferrequeststates_name[_transferrequeststates_index[i]:_transferrequeststates_index[i+1]] -} diff --git a/dsp/statemachine/transferrequeststates_gob.go b/dsp/statemachine/transferrequeststates_gob.go deleted file mode 100644 index ea2e2e8..0000000 --- a/dsp/statemachine/transferrequeststates_gob.go +++ /dev/null @@ -1,15 +0,0 @@ -package statemachine - -func (p TransferRequestState) GobEncode() ([]byte, error) { - return []byte(p.String()), nil -} - -func (p *TransferRequestState) GobDecode(b []byte) error { - newp, err := ParseTransferRequestState(b) - if err != nil { - return err - } - - *p = newp - return nil -} diff --git a/dsp/transfer/request.go b/dsp/transfer/request.go new file mode 100644 index 0000000..fcf7725 --- /dev/null +++ b/dsp/transfer/request.go @@ -0,0 +1,243 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transfer + +import ( + "bytes" + "encoding/gob" + "fmt" + "net/url" + "slices" + "strconv" + + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/shared" + "github.com/go-dataspace/run-dsp/odrl" + providerv1 "github.com/go-dataspace/run-dsrpc/gen/go/dsp/v1alpha1" + "github.com/google/uuid" +) + +var validTransferTransitions = map[State][]State{ + States.INITIAL: { + States.REQUESTED, + States.TERMINATED, + }, + States.REQUESTED: { + States.STARTED, + States.TERMINATED, + }, + States.STARTED: { + States.SUSPENDED, + States.COMPLETED, + States.TERMINATED, + }, + States.SUSPENDED: { + States.STARTED, + States.TERMINATED, + }, + States.COMPLETED: {}, +} + +type Direction uint8 + +const ( + DirectionUnknown Direction = iota + DirectionPull + DirectionPush +) + +// Request represents a transfer request and its state. +type Request struct { + state State + providerPID uuid.UUID + consumerPID uuid.UUID + agreementID uuid.UUID + target string + format string + callback *url.URL + self *url.URL + role constants.DataspaceRole + publishInfo *providerv1.PublishInfo + transferDirection Direction + + ro bool + modified bool +} + +type storableRequest struct { + State State + ProviderPID uuid.UUID + ConsumerPID uuid.UUID + AgreementID uuid.UUID + Target string + Format string + Callback *url.URL + Self *url.URL + Role constants.DataspaceRole + PublishInfo *providerv1.PublishInfo + TransferDirection Direction +} + +func New( + consumerPID uuid.UUID, + agreement *odrl.Agreement, + format string, + callback, self *url.URL, + role constants.DataspaceRole, + state State, + publishInfo *providerv1.PublishInfo, +) *Request { + targetID, err := shared.URNtoRawID(agreement.Target) + if err != nil { + panic("Misformed agreement, this means database corruption") + } + t := &Request{ + state: state, + consumerPID: consumerPID, + agreementID: uuid.MustParse(agreement.ID), + target: targetID, + format: format, + callback: callback, + self: self, + role: role, + publishInfo: publishInfo, + transferDirection: DirectionPush, + } + if publishInfo == nil { + t.transferDirection = DirectionPull + } + return t +} + +func FromBytes(b []byte) (*Request, error) { + var sr storableRequest + r := bytes.NewReader(b) + dec := gob.NewDecoder(r) + if err := dec.Decode(&sr); err != nil { + return nil, fmt.Errorf("could not decode bytes into storableRequest: %w", err) + } + return &Request{ + state: sr.State, + providerPID: sr.ProviderPID, + consumerPID: sr.ConsumerPID, + agreementID: sr.AgreementID, + target: sr.Target, + format: sr.Format, + callback: sr.Callback, + self: sr.Self, + role: sr.Role, + publishInfo: sr.PublishInfo, + transferDirection: sr.TransferDirection, + }, nil +} + +func GenerateKey(id uuid.UUID, role constants.DataspaceRole) []byte { + return []byte("transfer-" + id.String() + "-" + strconv.Itoa(int(role))) +} + +// Request getters. +func (tr *Request) GetProviderPID() uuid.UUID { return tr.providerPID } +func (tr *Request) GetConsumerPID() uuid.UUID { return tr.consumerPID } +func (tr *Request) GetAgreementID() uuid.UUID { return tr.agreementID } +func (tr *Request) GetTarget() string { return tr.target } +func (tr *Request) GetFormat() string { return tr.format } +func (tr *Request) GetCallback() *url.URL { return tr.callback } +func (tr *Request) GetSelf() *url.URL { return tr.self } +func (tr *Request) GetState() State { return tr.state } +func (tr *Request) GetRole() constants.DataspaceRole { return tr.role } +func (tr *Request) GetTransferRequest() *Request { return tr } +func (tr *Request) GetPublishInfo() *providerv1.PublishInfo { return tr.publishInfo } +func (tr *Request) GetTransferDirection() Direction { + return tr.transferDirection +} + +// Request setters, these will panic when the transfer is RO. +func (tr *Request) SetPublishInfo(pi *providerv1.PublishInfo) { + tr.panicRO() + tr.publishInfo = pi + tr.modify() +} + +func (tr *Request) SetProviderPID(id uuid.UUID) { + tr.panicRO() + tr.providerPID = id + tr.modify() +} + +func (tr *Request) SetState(state State) error { + tr.panicRO() + if !slices.Contains(validTransferTransitions[tr.state], state) { + return fmt.Errorf("can't transition from %s to %s", tr.state, state) + } + tr.state = state + tr.modify() + return nil +} + +// Properties that decisions are based on. +func (tr *Request) ReadOnly() bool { return tr.ro } +func (tr *Request) Modified() bool { return tr.modified } +func (tr *Request) StorageKey() []byte { + id := tr.consumerPID + if tr.role == constants.DataspaceProvider { + id = tr.providerPID + } + return GenerateKey(id, tr.role) +} + +// Property setters. +func (tr *Request) SetReadOnly() { tr.ro = true } + +func (tr *Request) ToBytes() ([]byte, error) { + s := storableRequest{ + State: tr.state, + ProviderPID: tr.providerPID, + ConsumerPID: tr.consumerPID, + AgreementID: tr.agreementID, + Target: tr.target, + Format: tr.format, + Callback: tr.callback, + Self: tr.self, + Role: tr.role, + PublishInfo: tr.publishInfo, + TransferDirection: tr.transferDirection, + } + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(s); err != nil { + return nil, fmt.Errorf("could not encode negotiation: %w", err) + } + return buf.Bytes(), nil +} + +func (tr *Request) GetTransferProcess() shared.TransferProcess { + return shared.TransferProcess{ + Context: shared.GetDSPContext(), + Type: "dspace:TransferProcess", + ProviderPID: tr.providerPID.URN(), + ConsumerPID: tr.consumerPID.URN(), + State: tr.state.String(), + } +} + +func (tr *Request) panicRO() { + if tr.ro { + panic("Trying to write to a read-only request, this is certainly a bug.") + } +} + +func (tr *Request) modify() { + tr.modified = true +} diff --git a/dsp/statemachine/transfer_request_state.go b/dsp/transfer/state.go similarity index 59% rename from dsp/statemachine/transfer_request_state.go rename to dsp/transfer/state.go index 36ac711..5e1be37 100644 --- a/dsp/statemachine/transfer_request_state.go +++ b/dsp/transfer/state.go @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -package statemachine +package transfer -type transferRequestState int +type state int //go:generate goenums transfer_request_state.go const ( - transferInitial transferRequestState = iota // INITIAL - transferRequested // dspace:REQUESTED - started // dspace:STARTED - suspended // dspace:SUSPENDED - completed // dspace:COMPLETED - transferTerminated // dspace:TERMINATED + initial state = iota // INITIAL + requested // dspace:REQUESTED + started // dspace:STARTED + suspended // dspace:SUSPENDED + completed // dspace:COMPLETED + terminated // dspace:TERMINATED ) diff --git a/dsp/transfer/states_enums.go b/dsp/transfer/states_enums.go new file mode 100644 index 0000000..44050b2 --- /dev/null +++ b/dsp/transfer/states_enums.go @@ -0,0 +1,176 @@ +// Code generated by goenums. DO NOT EDIT. +// This file was generated by github.com/zarldev/goenums +// using the command: +// goenums ./dsp/transfer/state.go + +package transfer + +import ( + "bytes" + "database/sql/driver" + "fmt" + "strconv" +) + +type State struct { + state +} + +type statesContainer struct { + INITIAL State + REQUESTED State + STARTED State + SUSPENDED State + COMPLETED State + TERMINATED State +} + +var States = statesContainer{ + INITIAL: State{ + state: initial, + }, + REQUESTED: State{ + state: requested, + }, + STARTED: State{ + state: started, + }, + SUSPENDED: State{ + state: suspended, + }, + COMPLETED: State{ + state: completed, + }, + TERMINATED: State{ + state: terminated, + }, +} + +func (c statesContainer) All() []State { + return []State{ + c.INITIAL, + c.REQUESTED, + c.STARTED, + c.SUSPENDED, + c.COMPLETED, + c.TERMINATED, + } +} + +var invalidState = State{} + +func ParseState(a any) (State, error) { + res := invalidState + switch v := a.(type) { + case State: + return v, nil + case []byte: + res = stringToState(string(v)) + case string: + res = stringToState(v) + case fmt.Stringer: + res = stringToState(v.String()) + case int: + res = intToState(v) + case int64: + res = intToState(int(v)) + case int32: + res = intToState(int(v)) + } + return res, nil +} + +func stringToState(s string) State { + switch s { + case "INITIAL": + return States.INITIAL + case "dspace:REQUESTED": + return States.REQUESTED + case "dspace:STARTED": + return States.STARTED + case "dspace:SUSPENDED": + return States.SUSPENDED + case "dspace:COMPLETED": + return States.COMPLETED + case "dspace:TERMINATED": + return States.TERMINATED + } + return invalidState +} + +func intToState(i int) State { + if i < 0 || i >= len(States.All()) { + return invalidState + } + return States.All()[i] +} + +func ExhaustiveStates(f func(State)) { + for _, p := range States.All() { + f(p) + } +} + +var validStates = map[State]bool{ + States.INITIAL: true, + States.REQUESTED: true, + States.STARTED: true, + States.SUSPENDED: true, + States.COMPLETED: true, + States.TERMINATED: true, +} + +func (p State) IsValid() bool { + return validStates[p] +} + +func (p State) MarshalJSON() ([]byte, error) { + return []byte(`"` + p.String() + `"`), nil +} + +func (p *State) UnmarshalJSON(b []byte) error { + b = bytes.Trim(bytes.Trim(b, `"`), ` `) + newp, err := ParseState(b) + if err != nil { + return err + } + *p = newp + return nil +} + +func (p *State) Scan(value any) error { + newp, err := ParseState(value) + if err != nil { + return err + } + *p = newp + return nil +} + +func (p State) Value() (driver.Value, error) { + return p.String(), nil +} + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the goenums command to generate them again. + // Does not identify newly added constant values unless order changes + var x [1]struct{} + _ = x[initial-0] + _ = x[requested-1] + _ = x[started-2] + _ = x[suspended-3] + _ = x[completed-4] + _ = x[terminated-5] +} + +const _states_name = "INITIALdspace:REQUESTEDdspace:STARTEDdspace:SUSPENDEDdspace:COMPLETEDdspace:TERMINATED" + +var _states_index = [...]uint16{0, 7, 23, 37, 53, 69, 86} + +func (i state) String() string { + if i < 0 || i >= state(len(_states_index)-1) { + return "states(" + (strconv.FormatInt(int64(i), 10) + ")") + } + return _states_name[_states_index[i]:_states_index[i+1]] +} diff --git a/dsp/transfer/states_gob.go b/dsp/transfer/states_gob.go new file mode 100644 index 0000000..feaa153 --- /dev/null +++ b/dsp/transfer/states_gob.go @@ -0,0 +1,15 @@ +package transfer + +func (p State) GobEncode() ([]byte, error) { + return []byte(p.String()), nil +} + +func (p *State) GobDecode(b []byte) error { + newp, err := ParseState(b) + if err != nil { + return err + } + + *p = newp + return nil +} diff --git a/dsp/transfer_handlers.go b/dsp/transfer_handlers.go index 33be7fb..43c3d50 100644 --- a/dsp/transfer_handlers.go +++ b/dsp/transfer_handlers.go @@ -15,20 +15,24 @@ package dsp import ( + "context" "fmt" "io" "net/http" "net/url" + "github.com/go-dataspace/run-dsp/dsp/constants" + "github.com/go-dataspace/run-dsp/dsp/persistence" "github.com/go-dataspace/run-dsp/dsp/shared" "github.com/go-dataspace/run-dsp/dsp/statemachine" + "github.com/go-dataspace/run-dsp/dsp/transfer" "github.com/go-dataspace/run-dsp/logging" "github.com/google/uuid" ) type TransferError struct { status int - transfer *statemachine.TransferRequest + transfer *transfer.Request dspCode string reason string err string @@ -60,7 +64,7 @@ func (te TransferError) ConsumerPID() string { } func transferError( - err string, statusCode int, dspCode string, reason string, transfer *statemachine.TransferRequest, + err string, statusCode int, dspCode string, reason string, transfer *transfer.Request, ) TransferError { return TransferError{ status: statusCode, @@ -78,7 +82,7 @@ func (dh *dspHandlers) providerTransferProcessHandler(w http.ResponseWriter, req return transferError("invalid provider ID", http.StatusBadRequest, "400", "Invalid provider PID", nil) } - contract, err := dh.store.GetProviderTransfer(req.Context(), providerPID) + contract, err := dh.store.GetTransferRW(req.Context(), providerPID, constants.DataspaceProvider) if err != nil { return contractError(err.Error(), http.StatusNotFound, "404", "TransferRequest not found", nil) } @@ -89,7 +93,6 @@ func (dh *dspHandlers) providerTransferProcessHandler(w http.ResponseWriter, req } func (dh *dspHandlers) providerTransferRequestHandler(w http.ResponseWriter, req *http.Request) error { - logger := logging.Extract(req.Context()) transferReq, err := shared.DecodeValid[shared.TransferRequestMessage](req) if err != nil { return transferError(fmt.Sprintf("invalid request message: %s", err.Error()), @@ -108,50 +111,38 @@ func (dh *dspHandlers) providerTransferRequestHandler(w http.ResponseWriter, req http.StatusBadRequest, "400", "Invalid request: agreement ID is not a UUID", nil) } + agreement, err := dh.store.GetAgreement(req.Context(), agreementID) + if err != nil { + return transferError(fmt.Sprintf("Could not get agreement with ID %s: %s", agreementID, err), + http.StatusNotFound, "404", "Invalid request: Agreement not found", nil) + } + cbURL, err := url.Parse(transferReq.CallbackAddress) if err != nil { return transferError(fmt.Sprintf("Invalid callback URL %s: %s", transferReq.CallbackAddress, err.Error()), http.StatusBadRequest, "400", "Invalid request: Non-valid callback URL.", nil) } - pState, err := statemachine.NewTransferRequest( - req.Context(), - dh.store, dh.provider, dh.reconciler, - consumerPID, agreementID, + request := transfer.New( + consumerPID, + agreement, transferReq.Format, - cbURL, dh.selfURL, - statemachine.DataspaceProvider, - statemachine.TransferRequestStates.TRANSFERINITIAL, + cbURL, + dh.selfURL, + constants.DataspaceProvider, + transfer.States.INITIAL, nil, ) - if err != nil { - return transferError(fmt.Sprintf("couldn't create transfer request: %s", err.Error()), - http.StatusInternalServerError, "500", "Failed to create transfer request", nil) - } - nextState, err := pState.Recv(req.Context(), transferReq) - if err != nil { - return transferError( - fmt.Sprintf("couldn't receive message: %s", err.Error()), - http.StatusBadRequest, "400", "Invalid request", pState.GetTransferRequest()) - } - - apply, err := nextState.Send(req.Context()) - if err != nil { - return transferError(fmt.Sprintf("couldn't progress to next state: %s", err.Error()), - http.StatusInternalServerError, "500", "Not able to progress state", nextState.GetTransferRequest()) + if err := storeRequest(req.Context(), dh.store, request); err != nil { + return err } - if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetTransferProcess()); err != nil { - logger.Error("Couldn't serve response", "err", err) - } - go apply() - return nil + return processTransferMessage(dh, w, req, request.GetRole(), request.GetProviderPID(), true, transferReq) } -//nolint:cyclop func progressTransferState[T any]( - dh *dspHandlers, w http.ResponseWriter, req *http.Request, role statemachine.DataspaceRole, + dh *dspHandlers, w http.ResponseWriter, req *http.Request, role constants.DataspaceRole, rawPID string, autoProgress bool, ) error { logger := logging.Extract(req.Context()) @@ -161,29 +152,32 @@ func progressTransferState[T any]( http.StatusBadRequest, "400", "Invalid request: PID is not a UUID", nil) } - var transfer *statemachine.TransferRequest - switch role { - case statemachine.DataspaceConsumer: - transfer, err = dh.store.GetConsumerTransfer(req.Context(), pid) - case statemachine.DataspaceProvider: - transfer, err = dh.store.GetProviderTransfer(req.Context(), pid) - default: - panic(fmt.Sprintf("unexpected statemachine.TransferRole: %#v", role)) - } - if err != nil { - return transferError(fmt.Sprintf("%d transfer request %s not found: %s", role, pid, err), - http.StatusNotFound, "404", "Transfer request not found", nil) - } - msg, err := shared.DecodeValid[T](req) if err != nil { return transferError(fmt.Sprintf("could not decode message: %s", err), - http.StatusBadRequest, "400", "Invalid request", transfer) + http.StatusBadRequest, "400", "Invalid request", nil) } - logger.Debug("Got contract message", "req", msg) + return processTransferMessage(dh, w, req, role, pid, autoProgress, msg) +} + +func processTransferMessage[T any]( + dh *dspHandlers, + w http.ResponseWriter, + req *http.Request, + role constants.DataspaceRole, + pid uuid.UUID, + autoProgress bool, + msg T, +) error { + logger := logging.Extract(req.Context()) + transfer, err := dh.store.GetTransferRW(req.Context(), pid, role) + if err != nil { + return transferError(fmt.Sprintf("%d transfer request %s not found: %s", role, pid, err), + http.StatusNotFound, "404", "Transfer request not found", nil) + } - pState := statemachine.GetTransferRequestNegotiation(dh.store, transfer, dh.provider, dh.reconciler) + pState := statemachine.GetTransferRequestNegotiation(transfer, dh.provider, dh.reconciler) nextState, err := pState.Recv(req.Context(), msg) if err != nil { @@ -200,6 +194,10 @@ func progressTransferState[T any]( } } + if err := storeRequest(req.Context(), dh.store, nextState.GetTransferRequest()); err != nil { + return err + } + if err := shared.EncodeValid(w, req, http.StatusOK, nextState.GetTransferProcess()); err != nil { logger.Error("Couldn't serve response", "err", err) } @@ -209,21 +207,33 @@ func progressTransferState[T any]( return nil } +func storeRequest( + ctx context.Context, + store persistence.StorageProvider, + request *transfer.Request, +) error { + if err := store.PutTransfer(ctx, request); err != nil { + return transferError(fmt.Sprintf("couldn't store transfer request: %s", err), + http.StatusInternalServerError, "500", "Not able to store transfer request", request) + } + return nil +} + func (dh *dspHandlers) providerTransferStartHandler(w http.ResponseWriter, req *http.Request) error { return progressTransferState[shared.TransferStartMessage]( - dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), false, + dh, w, req, constants.DataspaceProvider, req.PathValue("providerPID"), false, ) } func (dh *dspHandlers) providerTransferCompletionHandler(w http.ResponseWriter, req *http.Request) error { return progressTransferState[shared.TransferCompletionMessage]( - dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), true, + dh, w, req, constants.DataspaceProvider, req.PathValue("providerPID"), true, ) } func (dh *dspHandlers) providerTransferTerminationHandler(w http.ResponseWriter, req *http.Request) error { return progressTransferState[shared.TransferTerminationMessage]( - dh, w, req, statemachine.DataspaceProvider, req.PathValue("providerPID"), true, + dh, w, req, constants.DataspaceProvider, req.PathValue("providerPID"), true, ) } @@ -253,19 +263,19 @@ func (dh *dspHandlers) providerTransferSuspensionHandler(w http.ResponseWriter, func (dh *dspHandlers) consumerTransferStartHandler(w http.ResponseWriter, req *http.Request) error { return progressTransferState[shared.TransferStartMessage]( - dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), false, + dh, w, req, constants.DataspaceConsumer, req.PathValue("consumerPID"), false, ) } func (dh *dspHandlers) consumerTransferCompletionHandler(w http.ResponseWriter, req *http.Request) error { return progressTransferState[shared.TransferCompletionMessage]( - dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), true, + dh, w, req, constants.DataspaceConsumer, req.PathValue("consumerPID"), true, ) } func (dh *dspHandlers) consumerTransferTerminationHandler(w http.ResponseWriter, req *http.Request) error { return progressTransferState[shared.TransferTerminationMessage]( - dh, w, req, statemachine.DataspaceConsumer, req.PathValue("consumerPID"), true, + dh, w, req, constants.DataspaceConsumer, req.PathValue("consumerPID"), true, ) } diff --git a/internal/server/command.go b/internal/server/command.go index 3b4253a..7b9dc13 100644 --- a/internal/server/command.go +++ b/internal/server/command.go @@ -27,6 +27,7 @@ import ( "os" "os/signal" "path" + "strings" "sync" "time" @@ -50,51 +51,105 @@ import ( sloghttp "github.com/samber/slog-http" ) +// Viper keys. +const ( + dspAddress = "sever.dsp.address" + dspPort = "server.dsp.port" + dspExternalURL = "server.dsp.externalURL" + + providerAddress = "server.provider.address" + providerInsecure = "server.provider.insecure" + providerCACert = "server.provider.caCert" + providerClientCert = "server.provider.clientCert" + providerClientCertKey = "server.provider.clientCertKey" + + controlEnabled = "server.control.enabled" + controlAddr = "server.control.address" + controlPort = "server.control.port" + controlInsecure = "server.control.insecure" + controlCert = "server.control.cert" + controlCertKey = "server.control.certKey" + controlVerifyClientCertificates = "server.control.verifyClientCerts" + controlClientCACert = "server.control.clientCACert" + + persistenceBackend = "server.persistence.backend" + + persistenceBadgerMemory = "server.persistence.badger.memory" + persistenceBadgerDBPath = "server.persistence.badger.dbPath" +) + +// validStorageBackends are all the persistence backends we support. Right now, it's only badger. +var validStorageBackends = []string{"badger"} + // init initialises all the flags for the command. +// +//nolint:funlen // As all the flags live here, this will get rather long, TODO: split up. func init() { cfg.AddPersistentFlag( - Command, "server.dsp.address", "dsp-address", "address to listen on for dataspace operations", "0.0.0.0") + Command, dspAddress, "dsp-address", "address to listen on for dataspace operations", "0.0.0.0") cfg.AddPersistentFlag( - Command, "server.dsp.port", "dsp-port", "port to listen on for dataspace operations", 8080) + Command, dspPort, "dsp-port", "port to listen on for dataspace operations", 8080) cfg.AddPersistentFlag( Command, - "server.dsp.externalURL", + dspExternalURL, "external-url", "URL that the dataspace service is reachable by from the dataspace", "", ) cfg.AddPersistentFlag( - Command, "server.provider.address", "provider-address", "Address of the provider gRPC endpoint", "") + Command, providerAddress, "provider-address", "Address of the provider gRPC endpoint", "") cfg.AddPersistentFlag( - Command, "server.provider.insecure", "provider-insecure", "Disable TLS when connecting to provider", false) + Command, providerInsecure, "provider-insecure", "Disable TLS when connecting to provider", false) cfg.AddPersistentFlag( - Command, "server.provider.caCert", "provider-ca-cert", "CA certificate of provider cert issuer", "") + Command, providerCACert, "provider-ca-cert", "CA certificate of provider cert issuer", "") cfg.AddPersistentFlag( - Command, "server.provider.clientCert", "provider-client-cert", "Client certificate to use with provider", "") + Command, providerClientCert, "provider-client-cert", "Client certificate to use with provider", "") cfg.AddPersistentFlag( - Command, "server.provider.clientCertKey", "provider-client-cert-key", "Key for client certificate", "") + Command, providerClientCertKey, "provider-client-cert-key", "Key for client certificate", "") - cfg.AddPersistentFlag(Command, "server.control.enabled", "control-enabled", "enable gRPC control service", false) + cfg.AddPersistentFlag(Command, controlEnabled, "control-enabled", "enable gRPC control service", false) cfg.AddPersistentFlag( - Command, "server.control.address", "control-address", "address for the control service to listen on", "0.0.0.0") + Command, controlAddr, "control-address", "address for the control service to listen on", "0.0.0.0") cfg.AddPersistentFlag( - Command, "server.control.port", "control-port", "port for the control service to listen on", 8081) + Command, controlPort, "control-port", "port for the control service to listen on", 8081) cfg.AddPersistentFlag( - Command, "server.control.insecure", "control-insecure", "disable TLS for the control service", false) + Command, controlInsecure, "control-insecure", "disable TLS for the control service", false) cfg.AddPersistentFlag( - Command, "server.control.cert", "control-cert", "TLS certificate for the control service", "") + Command, controlCert, "control-cert", "TLS certificate for the control service", "") cfg.AddPersistentFlag( - Command, "server.control.certKey", "control-cert-key", "Key for control service certificate", "") + Command, controlCertKey, "control-cert-key", "Key for control service certificate", "") cfg.AddPersistentFlag( Command, - "server.control.verifyClientCerts", + controlVerifyClientCertificates, "control-verify-client-certs", "Require CA issued client certificates", false, ) cfg.AddPersistentFlag( - Command, "server.control.clientCACert", "control-client-ca-cert", "CA certificate of client cert issuer", "") + Command, controlClientCACert, "control-client-ca-cert", "CA certificate of client cert issuer", "") + + cfg.AddPersistentFlag( + Command, + persistenceBackend, + "persistence-backend", + fmt.Sprintf( + "What backend to store state in. Options: %s", + strings.Join(validStorageBackends, ","), + ), + "badger", + ) + + cfg.AddPersistentFlag( + Command, + persistenceBadgerMemory, + "badger-in-memory", + "Put badger database in memory, will not survive restarts", + false, + ) + cfg.AddPersistentFlag( + Command, persistenceBadgerDBPath, "badger-dbpath", "Path to store the badger database", "", + ) } // Command validates the configuration and then runs the server. @@ -104,46 +159,46 @@ var Command = &cobra.Command{ Long: `Starts the RUN-DSP connector, which then connects to the provider and will start serving dataspace requests`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - _, err := url.Parse(viper.GetString("server.dsp.externalURL")) + _, err := url.Parse(viper.GetString(dspExternalURL)) if err != nil { return fmt.Errorf("Invalid external URL: %w", err) } - err = cfg.CheckListenPort(viper.GetString("server.dsp.address"), viper.GetInt("server.dsp.port")) + err = cfg.CheckListenPort(viper.GetString(dspAddress), viper.GetInt(dspPort)) if err != nil { return err } - err = cfg.CheckConnectAddr(viper.GetString("server.provider.address")) + err = cfg.CheckConnectAddr(viper.GetString(providerAddress)) if err != nil { return err } - if !viper.GetBool("server.provider.insecure") { + if !viper.GetBool(providerInsecure) { err = cfg.CheckFilesExist( - viper.GetString("server.provider.caCert"), - viper.GetString("server.provider.clientCert"), - viper.GetString("server.provider.clientCertKey"), + viper.GetString(providerCACert), + viper.GetString(providerClientCert), + viper.GetString(providerClientCertKey), ) if err != nil { return err } } - if viper.GetBool("server.control.enabled") { - err = cfg.CheckListenPort(viper.GetString("server.control.address"), viper.GetInt("server.control.port")) + if viper.GetBool(controlEnabled) { + err = cfg.CheckListenPort(viper.GetString(controlAddr), viper.GetInt(controlPort)) if err != nil { return err } - if !viper.GetBool("server.control.insecure") { + if !viper.GetBool(controlInsecure) { err = cfg.CheckFilesExist( - viper.GetString("server.control.cert"), - viper.GetString("server.control.certKey"), + viper.GetString(controlCert), + viper.GetString(controlCertKey), ) if err != nil { return err } - if viper.GetBool("server.control.verifyClientCerts") { - err = cfg.CheckFilesExist(viper.GetString("server.control.clientCACert")) + if viper.GetBool(controlVerifyClientCertificates) { + err = cfg.CheckFilesExist(viper.GetString(controlClientCACert)) if err != nil { return err } @@ -152,30 +207,41 @@ var Command = &cobra.Command{ } + switch viper.GetString(persistenceBackend) { + case "badger": + mem := viper.GetBool(persistenceBadgerMemory) + path := viper.GetString(persistenceBadgerDBPath) + if mem && path != "" { + return fmt.Errorf("in-memory database is mutually exclusive with a database path") + } + default: + return fmt.Errorf("invalid persistence backend") + } + return nil }, RunE: func(cmd *cobra.Command, args []string) error { - u, err := url.Parse(viper.GetString("server.dsp.externalURL")) + u, err := url.Parse(viper.GetString(dspExternalURL)) if err != nil { panic(err.Error()) } c := command{ - ListenAddr: viper.GetString("server.dsp.address"), - Port: viper.GetInt("server.dsp.port"), + ListenAddr: viper.GetString(dspAddress), + Port: viper.GetInt(dspPort), ExternalURL: u, - ProviderAddress: viper.GetString("server.provider.address"), - ProviderInsecure: viper.GetBool("server.provider.insecure"), - ProviderCACert: viper.GetString("server.provider.caCert"), - ProviderClientCert: viper.GetString("server.provider.clientCert"), - ProviderClientCertKey: viper.GetString("server.provider.clientCert"), - ControlEnabled: viper.GetBool("server.control.enabled"), - ControlListenAddr: viper.GetString("server.control.address"), - ControlPort: viper.GetInt("server.control.port"), - ControlInsecure: viper.GetBool("server.control.insecure"), - ControlCert: viper.GetString("server.control.cert"), - ControlCertKey: viper.GetString("server.control.certKey"), - ControlVerifyClientCertificates: viper.GetBool("server.control.verifyClientCerts"), - ControlClientCACert: viper.GetString("server.control.clientCACert"), + ProviderAddress: viper.GetString(providerAddress), + ProviderInsecure: viper.GetBool(providerInsecure), + ProviderCACert: viper.GetString(providerCACert), + ProviderClientCert: viper.GetString(providerClientCert), + ProviderClientCertKey: viper.GetString(providerClientCertKey), + ControlEnabled: viper.GetBool(controlEnabled), + ControlListenAddr: viper.GetString(controlAddr), + ControlPort: viper.GetInt(controlPort), + ControlInsecure: viper.GetBool(controlInsecure), + ControlCert: viper.GetString(controlCert), + ControlCertKey: viper.GetString(controlCertKey), + ControlVerifyClientCertificates: viper.GetBool(controlVerifyClientCertificates), + ControlClientCACert: viper.GetString(controlClientCACert), } ctx, ok := viper.Get("initCTX").(context.Context) if !ok { @@ -191,7 +257,7 @@ type command struct { ExternalURL *url.URL - // GRPC settings for the provider + // GRPC settings for the provider. ProviderAddress string ProviderInsecure bool ProviderCACert string @@ -207,6 +273,13 @@ type command struct { ControlCertKey string ControlVerifyClientCertificates bool ControlClientCACert string + + // Persistence settings + PersistenceBackend string + + // Badger backend settings. + BadgerMemoryDB bool + BadgerDBPath string } // Run starts the server. @@ -215,25 +288,25 @@ func (c *command) Run(ctx context.Context) error { logger := logging.Extract(ctx) ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, os.Kill) defer cancel() - logger.Info("Starting server", "listenAddr", c.ListenAddr, "port", c.Port, "externalURL", c.ExternalURL, ) - provider, conn, err := c.getProvider(ctx) if err != nil { return err } defer conn.Close() - pingResponse, err := provider.Ping(ctx, &providerv1.PingRequest{}) if err != nil { return fmt.Errorf("could not ping provider: %w", err) } + store, err := c.getStorageProvider(ctx) + if err != nil { + return err + } - store := statemachine.NewMemoryArchiver() httpClient := &shared.HTTPRequester{} reconciler := statemachine.NewReconciler(ctx, httpClient, store) reconciler.Run() diff --git a/internal/server/persistence.go b/internal/server/persistence.go new file mode 100644 index 0000000..0d32688 --- /dev/null +++ b/internal/server/persistence.go @@ -0,0 +1,32 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "fmt" + + "github.com/go-dataspace/run-dsp/dsp/persistence" + "github.com/go-dataspace/run-dsp/dsp/persistence/badger" +) + +func (c *command) getStorageProvider(ctx context.Context) (persistence.StorageProvider, error) { + switch c.PersistenceBackend { + case "badger": + return badger.New(ctx, c.BadgerMemoryDB, c.BadgerDBPath) + default: + return nil, fmt.Errorf("Invalid backend") + } +}