diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index de94759..8633582 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/golangci/golangci-lint - rev: v1.58.0 + rev: v1.58.2 hooks: - id: golangci-lint - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/dsp/catalog_handlers.go b/dsp/catalog_handlers.go index 8fd7457..2da1e4c 100644 --- a/dsp/catalog_handlers.go +++ b/dsp/catalog_handlers.go @@ -128,7 +128,7 @@ func (ch *dspHandlers) datasetRequestHandler(w http.ResponseWriter, req *http.Re func processProviderDataset(pds *providerv1.Dataset, service shared.DataService) shared.Dataset { ds := shared.Dataset{ Resource: shared.Resource{ - ID: fmt.Sprintf("urn:uuid:%s", pds.GetId()), + ID: shared.IDToURN(pds.GetId()), Type: "dcat:Dataset", Title: pds.GetTitle(), Issued: pds.GetIssued().AsTime().Format(time.RFC3339), diff --git a/dsp/contract_handlers.go b/dsp/contract_handlers.go index 3ad6ee1..e5ad239 100644 --- a/dsp/contract_handlers.go +++ b/dsp/contract_handlers.go @@ -378,10 +378,7 @@ func (dh *dspHandlers) triggerConsumerContractRequestHandler(w http.ResponseWrit ctx, logger := logging.InjectLabels(req.Context(), "handler", "triggerConsumerContractRequestHandler") req = req.WithContext(ctx) - datasetID, err := uuid.Parse(req.PathValue("datasetID")) - if err != nil { - return fmt.Errorf("Dataset ID is not a UUID") - } + datasetID := shared.IDToURN(req.PathValue("datasetID")) logger.Debug("Got trigger request to start contract negotiation") selfURL, err := url.Parse(dh.selfURL.String()) @@ -401,7 +398,7 @@ func (dh *dspHandlers) triggerConsumerContractRequestHandler(w http.ResponseWrit ID: uuid.New().URN(), }, Type: "odrl:Offer", - Target: datasetID.URN(), + Target: datasetID, }, }, dh.selfURL, diff --git a/dsp/shared/id_helpers.go b/dsp/shared/id_helpers.go new file mode 100644 index 0000000..a2b21bd --- /dev/null +++ b/dsp/shared/id_helpers.go @@ -0,0 +1,66 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "fmt" + "strings" + + "github.com/go-dataspace/run-dsp/oid" + "github.com/google/uuid" +) + +// IDtoURN generates the URN for the ID. Right now we only support preformatted URNs, +// UUIDs, and OIDs. +func IDToURN(s string) string { + // If the string starts with urn:, we assume the string is already an URN. + if strings.HasPrefix(strings.ToLower(s), "urn:") { + return s + } + + // Check if we are dealing with a UUID. + if u, err := uuid.Parse(s); err == nil { + return u.URN() + } + + // If not, maybe an OID? + if o, err := oid.Parse(s); err == nil { + return o.URN() + } + + // If still unknown, return an unknown URN. + return fmt.Sprintf("urn:unknown:%s", s) +} + +// URNtoRawID strips the URN part and returns the ID without any metadata. +// This function only works on URNs with 3 parts like the uuid or oid URNs. +// This function returns an error when we can't properly split it. +// TODO: Verify that all providers support URNs so that we can remove this function. +func URNtoRawID(s string) (string, error) { + // If the ID doesn't start with "urn:", we just return it. + if !strings.HasPrefix(strings.ToLower(s), "urn:") { + return s, nil + } + + parts := strings.SplitN(s, ":", 3) + // If we don't get the right amount of parts, we return the full string, despite it most likely + // won't be the result we want. This should be exceedingly rare, and we might want to panic here + // instead to make the problem very obvious. + if len(parts) != 3 { + return "", fmt.Errorf("malformed URN: %s", s) + } + + return parts[2], nil +} diff --git a/dsp/statemachine/contract_transitions.go b/dsp/statemachine/contract_transitions.go index 5c2a30f..1be4be8 100644 --- a/dsp/statemachine/contract_transitions.go +++ b/dsp/statemachine/contract_transitions.go @@ -91,14 +91,14 @@ func (cn *ContractNegotiationInitial) Recv( ) logger.Debug("Received message") - target, err := uuid.Parse(cn.GetOffer().Target) + target, err := shared.URNtoRawID(cn.GetOffer().Target) if err != nil { - logger.Error("target is not a valid UUID", "err", err) - return ctx, nil, fmt.Errorf("target is not a valid UUID: %w", err) + logger.Error("can't parse URN", "err", err) + return ctx, nil, fmt.Errorf("can't parse URN: %w", err) } // This is the initial request, we can assume all data is freshly made based on the request. _, err = cn.GetProvider().GetDataset(ctx, &providerv1.GetDatasetRequest{ - DatasetId: target.String(), + DatasetId: target, }) if err != nil { logger.Error("target dataset not found", "err", err) @@ -159,13 +159,13 @@ func (cn *ContractNegotiationInitial) Send(ctx context.Context) (func(), error) } return sendContractRequest(ctx, cn.GetReconciler(), cn.GetContract()) case cn.GetProviderPID() != emptyUUID: - u, err := uuid.Parse(cn.GetOffer().Target) + targetID, err := shared.URNtoRawID(cn.GetOffer().Target) if err != nil { - logger.Error("invalid UUID", "err", err) - return func() {}, fmt.Errorf("invalid UUID `%s`: %w", cn.GetOffer().Target, err) + logger.Error("invalid URN", "err", err) + return func() {}, fmt.Errorf("invalid URN `%s`: %w", cn.GetOffer().Target, err) } _, err = cn.GetProvider().GetDataset(ctx, &providerv1.GetDatasetRequest{ - DatasetId: u.String(), + DatasetId: targetID, }) if err != nil { logger.Error("Dataset not found", "err", err) diff --git a/dsp/statemachine/transfer_request.go b/dsp/statemachine/transfer_request.go index b557eeb..c14b740 100644 --- a/dsp/statemachine/transfer_request.go +++ b/dsp/statemachine/transfer_request.go @@ -59,7 +59,7 @@ type TransferRequest struct { providerPID uuid.UUID consumerPID uuid.UUID agreementID uuid.UUID - target uuid.UUID + target string format string callback *url.URL self *url.URL @@ -71,7 +71,7 @@ type TransferRequest struct { func (tr *TransferRequest) GetProviderPID() uuid.UUID { return tr.providerPID } func (tr *TransferRequest) GetConsumerPID() uuid.UUID { return tr.consumerPID } func (tr *TransferRequest) GetAgreementID() uuid.UUID { return tr.agreementID } -func (tr *TransferRequest) GetTarget() uuid.UUID { return tr.target } +func (tr *TransferRequest) GetTarget() string { return tr.target } func (tr *TransferRequest) GetFormat() string { return tr.format } func (tr *TransferRequest) GetCallback() *url.URL { return tr.callback } func (tr *TransferRequest) GetSelf() *url.URL { return tr.self } diff --git a/dsp/statemachine/transfer_request_transitions.go b/dsp/statemachine/transfer_request_transitions.go index 526f0ad..17582f7 100644 --- a/dsp/statemachine/transfer_request_transitions.go +++ b/dsp/statemachine/transfer_request_transitions.go @@ -32,7 +32,7 @@ type TransferRequester interface { GetProviderPID() uuid.UUID GetConsumerPID() uuid.UUID GetAgreementID() uuid.UUID - GetTarget() uuid.UUID + GetTarget() string GetFormat() string GetCallback() *url.URL GetSelf() *url.URL @@ -65,7 +65,7 @@ func (tr *TransferRequestNegotiationInitial) Recv( switch t := message.(type) { case shared.TransferRequestMessage: _, err := tr.GetProvider().GetDataset(ctx, &providerv1.GetDatasetRequest{ - DatasetId: tr.GetTarget().String(), + DatasetId: tr.GetTarget(), }) if err != nil { return nil, fmt.Errorf("could not find target: %w", err) @@ -116,7 +116,7 @@ func (tr *TransferRequestNegotiationRequested) Send(ctx context.Context) (func() switch tr.GetTransferDirection() { case DirectionPull: resp, err := tr.GetProvider().PublishDataset(ctx, &providerv1.PublishDatasetRequest{ - DatasetId: tr.GetTarget().String(), + DatasetId: tr.GetTarget(), PublishId: tr.GetProviderPID().String(), }) if err != nil { @@ -212,11 +212,15 @@ func NewTransferRequest( if err != nil { return nil, fmt.Errorf("no agreement found") } + targetID, err := shared.URNtoRawID(agreement.Target) + if err != nil { + return nil, fmt.Errorf("couldn't parse target URN: %w", err) + } traReq := &TransferRequest{ state: state, consumerPID: consumerPID, agreementID: agreementID, - target: uuid.MustParse(agreement.Target), + target: targetID, format: format, callback: callback, self: self, diff --git a/oid/oid.go b/oid/oid.go new file mode 100644 index 0000000..5465db1 --- /dev/null +++ b/oid/oid.go @@ -0,0 +1,85 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package OID contains tools to work with OIDs. +package oid + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// urnPrefix contains the prefix of an OID formatted as an URN. +const urnPrefix = "urn:oid:" + +// OID represents an OID as a slice of ints. +type OID []int64 + +// validOID is a regex that checks if the string starts, and ends, with a number, and only +// contains numbers and periods in between. +var validOID = regexp.MustCompile(`^\d[\d\.]+\d$`) + +// Parse parses an OID string formatted as '1.2.3.4.5...' and returns an OID instance. It returns +// an error if the OID string can't be parsed. +func Parse(s string) (OID, error) { + if len(s) == 0 { + return OID{}, fmt.Errorf("can't parse empty string") + } + s = strings.ToLower(s) + if strings.HasPrefix(s, urnPrefix) { + s = strings.Replace(s, urnPrefix, "", 1) + } + if !validOID.MatchString(s) { + return OID{}, fmt.Errorf("invalid OID: %s", s) + } + parts := strings.Split(s, ".") + oid := make(OID, len(parts)) + for i, n := range parts { + var err error + oid[i], err = strconv.ParseInt(n, 10, 64) + if err != nil { + return oid, fmt.Errorf("could not parse OID string: %w", err) + } + } + return oid, nil +} + +// MustParse parses the OID string, but panics on error. +func MustParse(s string) OID { + oid, err := Parse(s) + if err != nil { + panic(err) + } + return oid +} + +// String returns the string representation of the OID. +func (o OID) String() string { + if o == nil { + o = OID{} + } + parts := make([]string, len(o)) + for i, n := range o { + parts[i] = strconv.FormatInt(n, 10) + } + return strings.Join(parts, ".") +} + +// URN returns the URN of the OID. +func (o OID) URN() string { + s := o.String() + return fmt.Sprintf("%s%s", urnPrefix, s) +} diff --git a/oid/oid_test.go b/oid/oid_test.go new file mode 100644 index 0000000..f40360f --- /dev/null +++ b/oid/oid_test.go @@ -0,0 +1,136 @@ +// Copyright 2024 go-dataspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package OID contains tools to work with OIDs. +package oid_test + +import ( + "testing" + + "github.com/go-dataspace/run-dsp/oid" + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want oid.OID + assertion assert.ErrorAssertionFunc + }{ + { + name: "Parses normal OID.", + args: args{ + s: "1.3.6.1.4.1.311.21.20", + }, + want: oid.OID{1, 3, 6, 1, 4, 1, 311, 21, 20}, + assertion: assert.NoError, + }, + { + name: "Errors on wrong character.", + args: args{ + s: "1.3.6.1a.4.1.311.21.20", + }, + want: oid.OID{}, + assertion: assert.Error, + }, + { + name: "Errors when not starting with number.", + args: args{ + s: ".3.6.1.4.1.311.21.20", + }, + want: oid.OID{}, + assertion: assert.Error, + }, + { + name: "Errors when not ending with number.", + args: args{ + s: "1.3.6.1.4.1.311.21.", + }, + want: oid.OID{}, + assertion: assert.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := oid.Parse(tt.args.s) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMustParse(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + }{ + { + name: "Panics on broken uuid", + args: args{ + s: "1.3a.6.1.4.1.311.21.20", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Panics(t, func() { _ = oid.MustParse(tt.args.s) }) + }) + } +} + +func TestOID_String(t *testing.T) { + tests := []struct { + name string + o oid.OID + want string + }{ + { + name: "Check formatting on normal OID", + o: oid.OID{1, 3, 6, 1, 4, 1, 311, 21, 20}, + want: "1.3.6.1.4.1.311.21.20", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.o.String()) + }) + } +} + +func TestOID_URN(t *testing.T) { + tests := []struct { + name string + o oid.OID + want string + }{ + { + name: "Check formatting on normal OID", + o: oid.OID{1, 3, 6, 1, 4, 1, 311, 21, 20}, + want: "urn:oid:1.3.6.1.4.1.311.21.20", + }, + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.o.URN()) + }) + } +}