diff --git a/cmd/mailroom/main.go b/cmd/mailroom/main.go index 791ba5f7d..a7411410a 100644 --- a/cmd/mailroom/main.go +++ b/cmd/mailroom/main.go @@ -28,6 +28,7 @@ import ( _ "github.com/nyaruka/mailroom/core/tasks/timeouts" _ "github.com/nyaruka/mailroom/services/external/omie" _ "github.com/nyaruka/mailroom/services/external/openai/chatgpt" + _ "github.com/nyaruka/mailroom/services/external/weni" _ "github.com/nyaruka/mailroom/services/ivr/twiml" _ "github.com/nyaruka/mailroom/services/ivr/vonage" _ "github.com/nyaruka/mailroom/services/tickets/intern" diff --git a/core/goflow/engine.go b/core/goflow/engine.go index 2ce4ab8a6..9e73d3db9 100644 --- a/core/goflow/engine.go +++ b/core/goflow/engine.go @@ -20,6 +20,7 @@ var classificationFactory func(*runtime.Config) engine.ClassificationServiceFact var ticketFactory func(*runtime.Config) engine.TicketServiceFactory var airtimeFactory func(*runtime.Config) engine.AirtimeServiceFactory var externalServiceFactory func(*runtime.Config) engine.ExternalServiceServiceFactory +var msgCatalogFactory func(*runtime.Config) engine.MsgCatalogServiceFactory // RegisterEmailServiceFactory can be used by outside callers to register a email factory // for use by the engine @@ -49,6 +50,10 @@ func RegisterExternalServiceServiceFactory(f func(*runtime.Config) engine.Extern externalServiceFactory = f } +func RegisterMsgCatalogServiceFactory(f func(*runtime.Config) engine.MsgCatalogServiceFactory) { + msgCatalogFactory = f +} + // Engine returns the global engine instance for use with real sessions func Engine(c *runtime.Config) flows.Engine { engInit.Do(func() { @@ -65,6 +70,7 @@ func Engine(c *runtime.Config) flows.Engine { WithEmailServiceFactory(emailFactory(c)). WithTicketServiceFactory(ticketFactory(c)). WithExternalServiceServiceFactory(externalServiceFactory((c))). + WithMsgCatalogServiceFactory(msgCatalogFactory((c))). // msg catalog WithAirtimeServiceFactory(airtimeFactory(c)). WithMaxStepsPerSprint(c.MaxStepsPerSprint). WithMaxResumesPerSession(c.MaxResumesPerSession). @@ -88,6 +94,7 @@ func Simulator(c *runtime.Config) flows.Engine { WithWebhookServiceFactory(webhooks.NewServiceFactory(httpClient, nil, httpAccess, webhookHeaders, c.WebhooksMaxBodyBytes)). WithClassificationServiceFactory(classificationFactory(c)). // simulated sessions do real classification WithExternalServiceServiceFactory(externalServiceFactory((c))). // and real external services + WithMsgCatalogServiceFactory(msgCatalogFactory((c))). // msg catalog WithEmailServiceFactory(simulatorEmailServiceFactory). // but faked emails WithTicketServiceFactory(simulatorTicketServiceFactory). // and faked tickets WithAirtimeServiceFactory(simulatorAirtimeServiceFactory). // and faked airtime transfers diff --git a/core/handlers/msg_catalog_created.go b/core/handlers/msg_catalog_created.go new file mode 100644 index 000000000..9c4163361 --- /dev/null +++ b/core/handlers/msg_catalog_created.go @@ -0,0 +1,127 @@ +package handlers + +import ( + "context" + "fmt" + + "github.com/nyaruka/gocommon/urns" + "github.com/nyaruka/goflow/envs" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/events" + "github.com/nyaruka/mailroom/core/hooks" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/runtime" + + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +func init() { + models.RegisterEventPreWriteHandler(events.TypeMsgCatalogCreated, handlePreMsgCatalogCreated) + models.RegisterEventHandler(events.TypeMsgCatalogCreated, handleMsgCatalogCreated) +} + +// handlePreMsgCatalogCreated clears our timeout on our session so that courier can send it when the message is sent, that will be set by courier when sent +func handlePreMsgCatalogCreated(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, oa *models.OrgAssets, scene *models.Scene, e flows.Event) error { + event := e.(*events.MsgCatalogCreatedEvent) + + // we only clear timeouts on messaging flows + if scene.Session().SessionType() != models.FlowTypeMessaging { + return nil + } + + // get our channel + var channel *models.Channel + + if event.Msg.Channel() != nil { + channel = oa.ChannelByUUID(event.Msg.Channel().UUID) + if channel == nil { + return errors.Errorf("unable to load channel with uuid: %s", event.Msg.Channel().UUID) + } + } + + // no channel? this is a no-op + if channel == nil { + return nil + } + + // android channels get normal timeouts + if channel.Type() == models.ChannelTypeAndroid { + return nil + } + + // everybody else gets their timeout cleared, will be set by courier + scene.Session().ClearTimeoutOn() + + return nil +} + +// handleMsgCreated creates the db msg for the passed in event +func handleMsgCatalogCreated(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, oa *models.OrgAssets, scene *models.Scene, e flows.Event) error { + event := e.(*events.MsgCatalogCreatedEvent) + + // must be in a session + if scene.Session() == nil { + return errors.Errorf("cannot handle msg created event without session") + } + + logrus.WithFields(logrus.Fields{ + "contact_uuid": scene.ContactUUID(), + "session_id": scene.SessionID(), + "text": event.Msg.Text(), + "header": event.Msg.Header(), + "products": event.Msg.Products(), + "urn": event.Msg.URN(), + "action": event.Msg.Action(), + }).Debug("msg created event") + + // messages in messaging flows must have urn id set on them, if not, go look it up + if scene.Session().SessionType() == models.FlowTypeMessaging && event.Msg.URN() != urns.NilURN { + urn := event.Msg.URN() + if models.GetURNInt(urn, "id") == 0 { + urn, err := models.GetOrCreateURN(ctx, tx, oa, scene.ContactID(), event.Msg.URN()) + if err != nil { + return errors.Wrapf(err, "unable to get or create URN: %s", event.Msg.URN()) + } + // update our Msg with our full URN + event.Msg.SetURN(urn) + } + } + + // get our channel + var channel *models.Channel + if event.Msg.Channel() != nil { + channel = oa.ChannelByUUID(event.Msg.Channel().UUID) + if channel == nil { + return errors.Errorf("unable to load channel with uuid: %s", event.Msg.Channel().UUID) + } else { + if fmt.Sprint(channel.Type()) == "WAC" || fmt.Sprint(channel.Type()) == "WA" { + country := envs.DeriveCountryFromTel("+" + event.Msg.URN().Path()) + locale := envs.NewLocale(scene.Contact().Language(), country) + languageCode := locale.ToBCP47() + + if _, valid := validLanguageCodes[languageCode]; !valid { + languageCode = "" + } + + event.Msg.TextLanguage = envs.Language(languageCode) + } + } + } + + msg, err := models.NewOutgoingFlowMsgCatalog(rt, oa.Org(), channel, scene.Session(), event.Msg, event.CreatedOn()) + if err != nil { + return errors.Wrapf(err, "error creating outgoing message to %s", event.Msg.URN()) + } + + // register to have this message committed + scene.AppendToEventPreCommitHook(hooks.CommitMessagesHook, msg) + + // don't send messages for surveyor flows + if scene.Session().SessionType() != models.FlowTypeSurveyor { + scene.AppendToEventPostCommitHook(hooks.SendMessagesHook, msg) + } + + return nil +} diff --git a/core/handlers/msg_catalog_created_test.go b/core/handlers/msg_catalog_created_test.go new file mode 100644 index 000000000..bfd2ecaa4 --- /dev/null +++ b/core/handlers/msg_catalog_created_test.go @@ -0,0 +1,100 @@ +package handlers_test + +import ( + "fmt" + "testing" + + "github.com/nyaruka/gocommon/urns" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/actions" + "github.com/nyaruka/mailroom/core/handlers" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + + "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" +) + +func TestMsgCatalogCreated(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetAll) + + // add a URN for cathy so we can test all urn sends + testdata.InsertContactURN(db, testdata.Org1, testdata.Cathy, urns.URN("tel:+12065551212"), 10) + + // delete all URNs for bob + db.MustExec(`DELETE FROM contacts_contacturn WHERE contact_id = $1`, testdata.Bob.ID) + + msg1 := testdata.InsertIncomingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "start", models.MsgStatusHandled) + + tcs := []handlers.TestCase{ + { + Actions: handlers.ContactActionMap{ + testdata.Cathy: []flows.Action{ + actions.NewSendMsgCatalog( + handlers.NewActionUUID(), + "", "Some products", "", "View Products", "", + []map[string]string{ + {"product_retailer_id": "9f526c6f-b2cb-4457-8048-a7f1dc101e50"}, + {"product_retailer_id": "eb2305cc-bf39-43ad-a069-bbbfb6401acc"}, + }, + false, + true, + ), + }, + testdata.George: []flows.Action{ + actions.NewSendMsgCatalog( + handlers.NewActionUUID(), + "Select The Service", "", "", "View Products", "", + []map[string]string{ + {"product_retailer_id": "cbd9ba07-7156-406e-8006-5b697d18d091"}, + {"product_retailer_id": "63157bd2-6f94-4dbb-b394-ea4eb07ce156"}, + }, + false, + true, + ), + }, + testdata.Bob: []flows.Action{ + actions.NewSendMsgCatalog(handlers.NewActionUUID(), "No URNs", "", "", "View Products", "i want a water bottle", nil, false, false), + }, + }, + Msgs: handlers.ContactMsgMap{ + testdata.Cathy: msg1, + }, + SQLAssertions: []handlers.SQLAssertion{ + { + SQL: "SELECT COUNT(*) FROM msgs_msg WHERE contact_id = $1 AND metadata = $2 AND high_priority = TRUE", + Args: []interface{}{testdata.Cathy.ID, `{"action":"View Products","body":"Some products","products":["9f526c6f-b2cb-4457-8048-a7f1dc101e50","eb2305cc-bf39-43ad-a069-bbbfb6401acc"]}`}, + Count: 2, + }, + { + SQL: "SELECT COUNT(*) FROM msgs_msg WHERE contact_id = $1 AND status = 'Q' AND high_priority = FALSE", + Args: []interface{}{testdata.George.ID}, + Count: 1, + }, + { + SQL: "SELECT COUNT(*) FROM msgs_msg WHERE contact_id=$1 AND STATUS = 'F' AND failed_reason = 'D';", + Args: []interface{}{testdata.Bob.ID}, + Count: 1, + }, + }, + }, + } + + handlers.RunTestCases(t, ctx, rt, tcs) + + rc := rp.Get() + defer rc.Close() + + // Cathy should have 1 batch of queued messages at high priority + count, err := redis.Int(rc.Do("zcard", fmt.Sprintf("msgs:%s|10/1", testdata.TwilioChannel.UUID))) + assert.NoError(t, err) + assert.Equal(t, 1, count) + + // One bulk for George + count, err = redis.Int(rc.Do("zcard", fmt.Sprintf("msgs:%s|10/0", testdata.TwilioChannel.UUID))) + assert.NoError(t, err) + assert.Equal(t, 1, count) +} diff --git a/core/models/assets.go b/core/models/assets.go index ebe2a32f7..218d252fd 100644 --- a/core/models/assets.go +++ b/core/models/assets.go @@ -79,6 +79,10 @@ type OrgAssets struct { externalServices []assets.ExternalService externalServicesByID map[ExternalServiceID]*ExternalService externalServicesByUUID map[assets.ExternalServiceUUID]*ExternalService + + msgCatalogs []assets.MsgCatalog + msgCatalogsByID map[CatalogID]*MsgCatalog + msgCatalogsByUUID map[assets.ChannelUUID]*MsgCatalog } var ErrNotFound = errors.New("not found") @@ -381,6 +385,24 @@ func NewOrgAssets(ctx context.Context, rt *runtime.Runtime, orgID OrgID, prev *O oa.externalServicesByUUID = prev.externalServicesByUUID } + if prev == nil || refresh&RefreshMsgCatalogs > 0 { + oa.msgCatalogs, err = loadCatalog(ctx, db, orgID) + if err != nil { + return nil, errors.Wrapf(err, "error loading catalogs for org %d", orgID) + } + oa.msgCatalogsByID = make(map[CatalogID]*MsgCatalog) + oa.msgCatalogsByUUID = make(map[assets.ChannelUUID]*MsgCatalog) + + for _, a := range oa.msgCatalogs { + oa.msgCatalogsByID[a.(*MsgCatalog).c.ID] = a.(*MsgCatalog) + oa.msgCatalogsByUUID[a.(*MsgCatalog).c.ChannelUUID] = a.(*MsgCatalog) + } + } else { + oa.msgCatalogs = prev.msgCatalogs + oa.msgCatalogsByID = prev.msgCatalogsByID + oa.msgCatalogsByUUID = prev.msgCatalogsByUUID + } + // intialize our session assets oa.sessionAssets, err = engine.NewSessionAssets(oa.Env(), oa, goflow.MigrationConfig(rt.Config)) if err != nil { @@ -414,6 +436,7 @@ const ( RefreshTopics = Refresh(1 << 15) RefreshUsers = Refresh(1 << 16) RefreshExternalServices = Refresh(1 << 17) + RefreshMsgCatalogs = Refresh(1 << 18) ) // GetOrgAssets creates or gets org assets for the passed in org @@ -706,3 +729,7 @@ func (a *OrgAssets) ExternalServiceByID(id ExternalServiceID) *ExternalService { func (a *OrgAssets) ExternalServiceByUUID(uuid assets.ExternalServiceUUID) *ExternalService { return a.externalServicesByUUID[uuid] } + +func (a *OrgAssets) MsgCatalogs() ([]assets.MsgCatalog, error) { + return a.msgCatalogs, nil +} diff --git a/core/models/catalog_products.go b/core/models/catalog_products.go new file mode 100644 index 000000000..8190d94e6 --- /dev/null +++ b/core/models/catalog_products.go @@ -0,0 +1,203 @@ +package models + +import ( + "context" + "database/sql" + "database/sql/driver" + "net/http" + "time" + + "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/goflow/assets" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/engine" + "github.com/nyaruka/mailroom/core/goflow" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/mailroom/utils/dbutil" + "github.com/nyaruka/null" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type CatalogID null.Int + +func (i CatalogID) MarshalJSON() ([]byte, error) { + return null.Int(i).MarshalJSON() +} + +func (i *CatalogID) UnmarshalJSON(b []byte) error { + return null.UnmarshalInt(b, (*null.Int)(i)) +} + +func (i CatalogID) Value() (driver.Value, error) { + return null.Int(i).Value() +} + +func (i *CatalogID) Scan(value interface{}) error { + return null.ScanInt(value, (*null.Int)(i)) +} + +// MsgCatalog represents a product catalog from Whatsapp channels. +type MsgCatalog struct { + c struct { + ID CatalogID `json:"id"` + UUID assets.MsgCatalogUUID `json:"uuid"` + FacebookCatalogID string `json:"facebook_catalog_id"` + Name string `json:"name"` + CreatedOn time.Time `json:"created_on"` + ModifiedOn time.Time `json:"modified_on"` + IsActive bool `json:"is_active"` + ChannelID ChannelID `json:"channel_id"` + OrgID OrgID `json:"org_id"` + ChannelUUID assets.ChannelUUID `json:"channel_uuid"` + Type string `json:"type"` + } +} + +func (c *MsgCatalog) ID() CatalogID { return c.c.ID } +func (c *MsgCatalog) UUID() assets.MsgCatalogUUID { return c.c.UUID } +func (c *MsgCatalog) FacebookCatalogID() string { return c.c.FacebookCatalogID } +func (c *MsgCatalog) Name() string { return c.c.Name } +func (c *MsgCatalog) CreatedOn() time.Time { return c.c.CreatedOn } +func (c *MsgCatalog) ModifiedOn() time.Time { return c.c.ModifiedOn } +func (c *MsgCatalog) IsActive() bool { return c.c.IsActive } +func (c *MsgCatalog) ChannelID() ChannelID { return c.c.ChannelID } +func (c *MsgCatalog) OrgID() OrgID { return c.c.OrgID } +func (c *MsgCatalog) Type() string { return c.c.Type } +func (c *MsgCatalog) ChannelUUID() assets.ChannelUUID { return c.c.ChannelUUID } + +func init() { + goflow.RegisterMsgCatalogServiceFactory(msgCatalogServiceFactory) +} + +func msgCatalogServiceFactory(c *runtime.Config) engine.MsgCatalogServiceFactory { + return func(session flows.Session, msgCatalog *flows.MsgCatalog) (flows.MsgCatalogService, error) { + return msgCatalog.Asset().(*MsgCatalog).AsService(c, msgCatalog) + } +} + +func (e *MsgCatalog) AsService(cfg *runtime.Config, msgCatalog *flows.MsgCatalog) (MsgCatalogService, error) { + httpClient, httpRetries, _ := goflow.HTTP(cfg) + + initFunc := msgCatalogServices["msg_catalog"] + if initFunc != nil { + return initFunc(cfg, httpClient, httpRetries, msgCatalog, nil) + } + + return nil, errors.Errorf("unrecognized product catalog %s", e.Name()) +} + +type MsgCatalogServiceFunc func(*runtime.Config, *http.Client, *httpx.RetryConfig, *flows.MsgCatalog, map[string]string) (MsgCatalogService, error) + +var msgCatalogServices = map[string]MsgCatalogServiceFunc{} + +type MsgCatalogService interface { + flows.MsgCatalogService +} + +func RegisterMsgCatalogService(name string, initFunc MsgCatalogServiceFunc) { + msgCatalogServices[name] = initFunc +} + +const getActiveCatalogSQL = ` +SELECT ROW_TO_JSON(r) FROM (SELECT + c.id as id, + c.uuid as uuid, + c.facebook_catalog_id as facebook_catalog_id, + c.name as name, + c.created_on as created_on, + c.modified_on as modified_on, + c.is_active as is_active, + c.channel_id as channel_id, + c.org_id as org_id +FROM + public.wpp_products_catalog c +WHERE + channel_id = $1 AND is_active = true +) r; +` + +// GetActiveCatalogFromChannel returns the active catalog from the given channel +func GetActiveCatalogFromChannel(ctx context.Context, db sqlx.DB, channelID ChannelID) (*MsgCatalog, error) { + var catalog MsgCatalog + + rows, err := db.QueryxContext(ctx, getActiveCatalogSQL, channelID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, errors.Wrapf(err, "error getting active catalog for channelID: %d", channelID) + } + defer rows.Close() + + for rows.Next() { + err = dbutil.ReadJSONRow(rows, &catalog.c) + if err != nil { + return nil, err + } + } + + return &catalog, nil +} + +func loadCatalog(ctx context.Context, db *sqlx.DB, orgID OrgID) ([]assets.MsgCatalog, error) { + start := time.Now() + + rows, err := db.Queryx(selectOrgCatalogSQL, orgID) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrapf(err, "error querying catalog for org: %d", orgID) + } + defer rows.Close() + + catalog := make([]assets.MsgCatalog, 0) + for rows.Next() { + msgCatalog := &MsgCatalog{} + err := dbutil.ReadJSONRow(rows, &msgCatalog.c) + if err != nil { + return nil, errors.Wrapf(err, "error unmarshalling catalog") + } + channelUUID, err := ChannelUUIDForChannelID(ctx, db, msgCatalog.ChannelID()) + if err != nil { + return nil, err + } + msgCatalog.c.ChannelUUID = channelUUID + msgCatalog.c.Type = "msg_catalog" + catalog = append(catalog, msgCatalog) + } + + logrus.WithField("elapsed", time.Since(start)).WithField("org_id", orgID).WithField("count", len(catalog)).Debug("loaded catalog") + + return catalog, nil +} + +const selectOrgCatalogSQL = ` +SELECT ROW_TO_JSON(r) FROM (SELECT + c.id as id, + c.uuid as uuid, + c.facebook_catalog_id as facebook_catalog_id, + c.name as name, + c.created_on as created_on, + c.org_id as org_id, + c.modified_on as modified_on, + c.is_active as is_active, + c.channel_id as channel_id +FROM + public.wpp_products_catalog c +WHERE + c.org_id = $1 AND + c.is_active = TRUE +ORDER BY + c.created_on ASC +) r; +` + +// ChannelForChannelID returns the channel for the passed in channel ID if any +func ChannelUUIDForChannelID(ctx context.Context, db *sqlx.DB, channelID ChannelID) (assets.ChannelUUID, error) { + var channelUUID assets.ChannelUUID + err := db.GetContext(ctx, &channelUUID, `SELECT uuid FROM channels_channel WHERE id = $1 AND is_active = TRUE`, channelID) + if err != nil { + return assets.ChannelUUID(""), errors.Wrapf(err, "no channel found with id: %d", channelID) + } + return channelUUID, nil +} diff --git a/core/models/catalog_products_test.go b/core/models/catalog_products_test.go new file mode 100644 index 000000000..e4389dc53 --- /dev/null +++ b/core/models/catalog_products_test.go @@ -0,0 +1,61 @@ +package models_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/stretchr/testify/assert" +) + +func TestCatalogProducts(t *testing.T) { + ctx, _, db, _ := testsuite.Get() + defer testsuite.Reset(testsuite.ResetDB) + + // _, err := db.Exec(catalogProductDDL) + // if err != nil { + // t.Fatal(err) + // } + + _, err := db.Exec(`INSERT INTO public.wpp_products_catalog + (uuid, facebook_catalog_id, "name", created_on, modified_on, is_active, channel_id, org_id) + VALUES('2be9092a-1c97-4b24-906f-f0fbe3e1e93e', '123456789', 'Catalog Dummy', now(), now(), true, $1, $2); + `, testdata.Org2Channel.ID, testdata.Org2.ID) + assert.NoError(t, err) + + ctp, err := models.GetActiveCatalogFromChannel(ctx, *db, testdata.Org2Channel.ID) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, true, ctp.IsActive()) + + _, err = db.Exec(`INSERT INTO public.wpp_products_catalog + (uuid, facebook_catalog_id, "name", created_on, modified_on, is_active, channel_id, org_id) + VALUES('9bbe354d-cea6-408b-ba89-9ce28999da3f', '1234567891', 'Catalog Dummy2', now(), now(), false, $1, $2); + `, 123, testdata.Org2.ID) + fmt.Println(err) + assert.NoError(t, err) + + ctpn, err := models.GetActiveCatalogFromChannel(ctx, *db, 123) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, &models.MsgCatalog{}, ctpn) + +} + +func TestChannelUUIDForChannelID(t *testing.T) { + ctx, _, db, _ := testsuite.Get() + defer testsuite.Reset(testsuite.ResetAll) + + ctxp, cancelp := context.WithTimeout(ctx, time.Second*5) + defer cancelp() + + ctp, err := models.ChannelUUIDForChannelID(ctxp, db, testdata.TwilioChannel.ID) + assert.NoError(t, err) + assert.Equal(t, ctp, testdata.Org2Channel.UUID) +} diff --git a/core/models/channels.go b/core/models/channels.go index 838a0609f..d66e032a5 100644 --- a/core/models/channels.go +++ b/core/models/channels.go @@ -8,6 +8,7 @@ import ( "time" "github.com/lib/pq" + "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" "github.com/nyaruka/mailroom/utils/dbutil" @@ -146,6 +147,40 @@ func GetChannelsByID(ctx context.Context, db Queryer, ids []ChannelID) ([]*Chann return channels, nil } +const selectActiveChannelByUUIDSQL = ` +SELECT ROW_TO_JSON(r) FROM (SELECT + c.id as id, + c.uuid as uuid, + c.name as name, + c.channel_type as channel_type, + COALESCE(c.tps, 10) as tps, + COALESCE(c.config, '{}')::json as config +FROM + channels_channel c +WHERE + c.uuid = $1 + and c.is_active = TRUE +) r; +` + +func GetActiveChannelByUUID(ctx context.Context, db Queryer, channelUUID uuids.UUID) (*Channel, error) { + rows, err := db.QueryxContext(ctx, selectActiveChannelByUUIDSQL, channelUUID) + if err != nil { + return nil, errors.Wrapf(err, "error querying channel by uuid") + } + defer rows.Close() + + channel := &Channel{} + for rows.Next() { + err := dbutil.ReadJSONRow(rows, &channel.c) + if err != nil { + return nil, errors.Wrapf(err, "error unmarshalling channel") + } + } + + return channel, nil +} + const selectChannelsByIDSQL = ` SELECT ROW_TO_JSON(r) FROM (SELECT c.id as id, diff --git a/core/models/msgs.go b/core/models/msgs.go index 6351d69bc..38457fb49 100644 --- a/core/models/msgs.go +++ b/core/models/msgs.go @@ -331,6 +331,11 @@ func NewOutgoingFlowMsg(rt *runtime.Runtime, org *Org, channel *Channel, session return newOutgoingMsg(rt, org, channel, session.ContactID(), out, createdOn, session, NilBroadcastID) } +// NewOutgoingFlowMsgCatalog creates an outgoing message for the passed in flow message +func NewOutgoingFlowMsgCatalog(rt *runtime.Runtime, org *Org, channel *Channel, session *Session, out *flows.MsgCatalogOut, createdOn time.Time) (*Msg, error) { + return newOutgoingMsgCatalog(rt, org, channel, session.ContactID(), out, createdOn, session, NilBroadcastID) +} + // NewOutgoingBroadcastMsg creates an outgoing message which is part of a broadcast func NewOutgoingBroadcastMsg(rt *runtime.Runtime, org *Org, channel *Channel, contactID ContactID, out *flows.MsgOut, createdOn time.Time, broadcastID BroadcastID) (*Msg, error) { return newOutgoingMsg(rt, org, channel, contactID, out, createdOn, nil, broadcastID) @@ -423,6 +428,89 @@ func newOutgoingMsg(rt *runtime.Runtime, org *Org, channel *Channel, contactID C return msg, nil } +func newOutgoingMsgCatalog(rt *runtime.Runtime, org *Org, channel *Channel, contactID ContactID, msgCatalog *flows.MsgCatalogOut, createdOn time.Time, session *Session, broadcastID BroadcastID) (*Msg, error) { + msg := &Msg{} + m := &msg.m + m.UUID = msgCatalog.UUID() + m.Text = msgCatalog.Text() + m.HighPriority = false + m.Direction = DirectionOut + m.Status = MsgStatusQueued + m.Visibility = VisibilityVisible + m.MsgType = MsgTypeFlow + m.MsgCount = 1 + m.ContactID = contactID + m.BroadcastID = broadcastID + m.OrgID = org.ID() + m.TopupID = NilTopupID + m.CreatedOn = createdOn + + msg.SetChannel(channel) + msg.SetURN(msgCatalog.URN()) + + if org.Suspended() { + // we fail messages for suspended orgs right away + m.Status = MsgStatusFailed + m.FailedReason = MsgFailedSuspended + } else if msg.URN() == urns.NilURN || channel == nil { + // if msg is missing the URN or channel, we also fail it + m.Status = MsgStatusFailed + m.FailedReason = MsgFailedNoDestination + } + + // if we have a session, set fields on the message from that + if session != nil { + m.ResponseToExternalID = session.IncomingMsgExternalID() + m.SessionID = session.ID() + m.SessionStatus = session.Status() + + // if we're responding to an incoming message, send as high priority + if session.IncomingMsgID() != NilMsgID { + m.HighPriority = true + } + } + + // populate metadata if we have any + if msgCatalog.Topic() != flows.NilMsgTopic || msgCatalog.TextLanguage != "" || msgCatalog.Header() != "" || msgCatalog.Body() != "" || msgCatalog.Footer() != "" || len(msgCatalog.Products()) != 0 { + metadata := make(map[string]interface{}) + if msgCatalog.Topic() != flows.NilMsgTopic { + metadata["topic"] = string(msgCatalog.Topic()) + } + if msgCatalog.TextLanguage != "" { + metadata["text_language"] = msgCatalog.TextLanguage + } + if msgCatalog.Header() != "" { + metadata["header"] = string(msgCatalog.Header()) + } + if msgCatalog.Body() != "" { + metadata["body"] = string(msgCatalog.Body()) + } + if msgCatalog.Footer() != "" { + metadata["footer"] = string(msgCatalog.Footer()) + } + if len(msgCatalog.Body()) != 0 { + metadata["products"] = msgCatalog.Products() + } + if msgCatalog.Action() != "" { + metadata["action"] = msgCatalog.Action() + } + if msgCatalog.Smart() { + metadata["send_catalog"] = false + } else { + metadata["send_catalog"] = msgCatalog.SendCatalog() + } + + m.Metadata = null.NewMap(metadata) + } + + // if we're sending to a phone, message may have to be sent in multiple parts + if m.URN.Scheme() == urns.TelScheme { + m.MsgCount = gsm7.Segments(m.Text) + len(m.Attachments) + } + + return msg, nil +} + // NewIncomingMsg creates a new incoming message for the passed in text and attachment func NewIncomingMsg(cfg *runtime.Config, orgID OrgID, channel *Channel, contactID ContactID, in *flows.MsgIn, createdOn time.Time) *Msg { msg := &Msg{} diff --git a/core/models/runs.go b/core/models/runs.go index 3158ca3b0..03ba1f7c5 100644 --- a/core/models/runs.go +++ b/core/models/runs.go @@ -94,8 +94,9 @@ var exitToRunStatusMap = map[ExitType]RunStatus{ } var keptEvents = map[string]bool{ - events.TypeMsgCreated: true, - events.TypeMsgReceived: true, + events.TypeMsgCreated: true, + events.TypeMsgCatalogCreated: true, + events.TypeMsgReceived: true, } // Session is the mailroom type for a FlowSession diff --git a/go.mod b/go.mod index b0c669599..4384e1268 100644 --- a/go.mod +++ b/go.mod @@ -70,4 +70,6 @@ go 1.17 replace github.com/nyaruka/gocommon => github.com/Ilhasoft/gocommon v1.16.2-weni -replace github.com/nyaruka/goflow => github.com/weni-ai/goflow v0.2.3-goflow-0.144.3 + +replace github.com/nyaruka/goflow => github.com/weni-ai/goflow v0.3.0-goflow-0.144.3 + diff --git a/go.sum b/go.sum index 01bc021ca..597bf4f67 100644 --- a/go.sum +++ b/go.sum @@ -204,8 +204,8 @@ github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLD github.com/tj/go-elastic v0.0.0-20171221160941-36157cbbebc2/go.mod h1:WjeM0Oo1eNAjXGDx2yma7uG2XoyRZTq1uv3M/o7imD0= github.com/tj/go-kinesis v0.0.0-20171128231115-08b17f58cb1b/go.mod h1:/yhzCV0xPfx6jb1bBgRFjl5lytqVqZXEaeqWP8lTEao= github.com/tj/go-spin v1.1.0/go.mod h1:Mg1mzmePZm4dva8Qz60H2lHwmJ2loum4VIrLgVnKwh4= -github.com/weni-ai/goflow v0.2.3-goflow-0.144.3 h1:COYyomogTzlpPFbGfoA3y5hUG3KOYxaC0oa4jk6Kez8= -github.com/weni-ai/goflow v0.2.3-goflow-0.144.3/go.mod h1:o0xaVWP9qNcauBSlcNLa79Fm2oCPV+BDpheFRa/D40c= +github.com/weni-ai/goflow v0.3.0-goflow-0.144.3 h1:I6d3rcMBCS56McFVq0eAN17Nk8sPVDf/315uCosOWtM= +github.com/weni-ai/goflow v0.3.0-goflow-0.144.3/go.mod h1:o0xaVWP9qNcauBSlcNLa79Fm2oCPV+BDpheFRa/D40c= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/mailroom_test.dump b/mailroom_test.dump index 2d2169f80..3d8169331 100644 Binary files a/mailroom_test.dump and b/mailroom_test.dump differ diff --git a/runtime/config.go b/runtime/config.go index 4162e66dc..c97bbb2c1 100644 --- a/runtime/config.go +++ b/runtime/config.go @@ -79,6 +79,14 @@ type Config struct { ZeroshotAPIToken string `help:"secret token for zeroshot API authentication and authorization"` ZeroshotAPIUrl string `help:"zeroshot API base url"` + ChatgptKey string `help:"chat gpt api key"` + ChatgptBaseURL string `help:"chat gpt base url"` + + WenigptAuthToken string `help:"wenigpt authorization token"` + WenigptCookie string `help:"wenigpt cookie"` + WenigptBaseURL string `help:"wenigpt url"` + + SentenxBaseURL string `help:"sentenx base url"` } // NewDefaultConfig returns a new default configuration object @@ -136,6 +144,8 @@ func NewDefaultConfig() *Config { ZeroshotAPIToken: "", ZeroshotAPIUrl: "http://engine-ai.dev.cloud.weni.ai", + ChatgptBaseURL: "https://api.openai.com", + SentenxBaseURL: "https://sentenx.weni.ai", } } diff --git a/services/external/weni/sentenx/client.go b/services/external/weni/sentenx/client.go new file mode 100644 index 000000000..950bc1ef1 --- /dev/null +++ b/services/external/weni/sentenx/client.go @@ -0,0 +1,108 @@ +package sentenx + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/gocommon/jsonx" + "github.com/pkg/errors" +) + +const BaseURL = "https://sentenx.weni.ai" + +type SearchRequest struct { + Search string `json:"search,omitempty"` + Filter struct { + CatalogID string `json:"catalog_id,omitempty"` + } `json:"filter,omitempty"` + Threshold float64 `json:"threshold,omitempty"` +} + +func NewSearchRequest(search, catalogID string, threshold float64) *SearchRequest { + return &SearchRequest{ + Search: search, + Filter: struct { + CatalogID string `json:"catalog_id,omitempty"` + }{ + CatalogID: catalogID, + }, + Threshold: threshold, + } +} + +type SearchResponse struct { + Products []Product `json:"products,omitempty"` +} + +type Product struct { + ProductRetailerID string `json:"product_retailer_id,omitempty"` +} + +type ErrorResponse struct { + Detail []struct { + Msg string `json:"msg,omitempty"` + } `json:"detail,omitempty"` +} + +type Client struct { + httpClient *http.Client + httpRetries *httpx.RetryConfig + baseURL string +} + +func NewClient(httpClient *http.Client, httpRetries *httpx.RetryConfig, baseURL string) *Client { + return &Client{httpClient, httpRetries, baseURL} +} + +func (c *Client) Request(method, url string, body, response interface{}) (*httpx.Trace, error) { + b, err := json.Marshal(body) + if err != nil { + return nil, err + } + + data := strings.NewReader(string(b)) + req, err := httpx.NewRequest(method, url, data, nil) + if err != nil { + return nil, err + } + + trace, err := httpx.DoTrace(c.httpClient, req, c.httpRetries, nil, -1) + if err != nil { + return trace, err + } + + if trace.Response.StatusCode >= 400 { + var errorResponse ErrorResponse + err = jsonx.Unmarshal(trace.ResponseBody, &errorResponse) + if err != nil { + return trace, err + } + concatenatedErrorMsg := "" + for i, msg := range errorResponse.Detail { + concatenatedErrorMsg += msg.Msg + if i < len(errorResponse.Detail)-1 { + concatenatedErrorMsg += ". " + } + } + return trace, errors.New(concatenatedErrorMsg) + } + + if response != nil { + err := json.Unmarshal(trace.ResponseBody, response) + return trace, errors.Wrap(err, "couldn't unmarshal response body") + } + return trace, nil +} + +func (c *Client) SearchProducts(data *SearchRequest) (*SearchResponse, *httpx.Trace, error) { + requestURL := c.baseURL + "/products/search" + response := &SearchResponse{} + + trace, err := c.Request("GET", requestURL, data, response) + if err != nil { + return nil, trace, err + } + return response, trace, nil +} diff --git a/services/external/weni/sentenx/client_test.go b/services/external/weni/sentenx/client_test.go new file mode 100644 index 000000000..d2cec7999 --- /dev/null +++ b/services/external/weni/sentenx/client_test.go @@ -0,0 +1,90 @@ +package sentenx_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/mailroom/services/external/weni/sentenx" + "github.com/stretchr/testify/assert" +) + +const ( + baseURL = "https://sentenx.weni.ai" +) + +func TestRequest(t *testing.T) { + client := sentenx.NewClient(http.DefaultClient, nil, baseURL) + + _, err := client.Request("POST", "", func() {}, nil) + assert.Error(t, err) + + _, err = client.Request("{[:INVALID:]}", "", nil, nil) + assert.Error(t, err) + + defer httpx.SetRequestor(httpx.DefaultRequestor) + httpx.SetRequestor(httpx.NewMockRequestor(map[string][]httpx.MockResponse{ + baseURL: { + httpx.NewMockResponse(400, nil, `[]`), + httpx.NewMockResponse(200, nil, `{}`), + httpx.NewMockResponse(400, nil, `{ + "detail": [ + { "msg": "dummy error message"} + ] + }`), + }, + })) + + _, err = client.Request("GET", baseURL, nil, nil) + assert.Error(t, err) + + _, err = client.Request("GET", baseURL, nil, nil) + assert.Nil(t, err) + + response := new(interface{}) + _, err = client.Request("GET", baseURL, nil, response) + assert.Error(t, err) +} + +func TestSearch(t *testing.T) { + defer httpx.SetRequestor(httpx.DefaultRequestor) + httpx.SetRequestor(httpx.NewMockRequestor(map[string][]httpx.MockResponse{ + fmt.Sprintf("%s/products/search", baseURL): { + httpx.NewMockResponse(400, nil, `{ + "detail": [{"msg": "dummy error msg"}, {"msg": "dummy error msg 2"}] + }`), + httpx.NewMockResponse(200, nil, `{ + "products": [ + { + "facebook_id": "1234567891", + "title": "banana prata 1kg", + "org_id": "1", + "channel_id": "5", + "catalog_id": "asdfgh", + "product_retailer_id": "p1" + }, + { + "facebook_id": "1234567892", + "title": "doce de banana 250g", + "org_id": "1", + "channel_id": "5", + "catalog_id": "asdfgh", + "product_retailer_id": "p2" + } + ] + }`), + }, + })) + + client := sentenx.NewClient(http.DefaultClient, nil, baseURL) + + data := sentenx.NewSearchRequest("banana", "asdfgh", 1.6) + + _, _, err := client.SearchProducts(data) + assert.EqualError(t, err, "dummy error msg. dummy error msg 2") + + sres, _, err := client.SearchProducts(data) + assert.NoError(t, err) + assert.Equal(t, "p1", sres.Products[0].ProductRetailerID) +} diff --git a/services/external/weni/service.go b/services/external/weni/service.go new file mode 100644 index 000000000..f9a14da23 --- /dev/null +++ b/services/external/weni/service.go @@ -0,0 +1,211 @@ +package catalogs + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "sync" + "time" + + "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/goflow/assets" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/utils" + "github.com/nyaruka/mailroom/core/goflow" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/mailroom/services/external/openai/chatgpt" + "github.com/nyaruka/mailroom/services/external/weni/sentenx" + "github.com/nyaruka/mailroom/services/external/weni/wenigpt" + "github.com/pkg/errors" +) + +const ( + serviceType = "msg_catalog" +) + +var db *sqlx.DB +var mu = &sync.Mutex{} + +func initDB(dbURL string) error { + mu.Lock() + defer mu.Unlock() + if db == nil { + newDB, err := sqlx.Open("postgres", dbURL) + if err != nil { + return errors.Wrap(err, "unable to open database connection") + } + SetDB(newDB) + } + return nil +} + +func SetDB(newDB *sqlx.DB) { + db = newDB +} + +func init() { + models.RegisterMsgCatalogService(serviceType, NewService) +} + +type service struct { + rtConfig *runtime.Config + restClient *http.Client + redactor utils.Redactor +} + +func NewService(rtCfg *runtime.Config, httpClient *http.Client, httpRetries *httpx.RetryConfig, msgCatalog *flows.MsgCatalog, config map[string]string) (models.MsgCatalogService, error) { + + if err := initDB(rtCfg.DB); err != nil { + return nil, err + } + + return &service{ + rtConfig: rtCfg, + restClient: httpClient, + redactor: utils.NewRedactor(flows.RedactionMask), + }, nil +} + +func (s *service) Call(session flows.Session, params assets.MsgCatalogParam, logHTTP flows.HTTPLogCallback) (*flows.MsgCatalogCall, error) { + callResult := &flows.MsgCatalogCall{} + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + content := params.ProductSearch + productList, traceWeniGPT, err := GetProductListFromChatGPT(ctx, s.rtConfig, content) + callResult.TraceWeniGPT = traceWeniGPT + if err != nil { + return callResult, err + } + channelUUID := params.ChannelUUID + channel, err := models.GetActiveChannelByUUID(ctx, db, channelUUID) + if err != nil { + return callResult, err + } + + catalog, err := models.GetActiveCatalogFromChannel(ctx, *db, channel.ID()) + if err != nil { + return callResult, err + } + channelThreshold := channel.ConfigValue("threshold", "1.5") + searchThreshold, err := strconv.ParseFloat(channelThreshold, 64) + if err != nil { + return callResult, err + } + + productRetailerIDS := []string{} + + for _, product := range productList { + searchResult, trace, err := GetProductListFromSentenX(product, catalog.FacebookCatalogID(), searchThreshold, s.rtConfig) + callResult.TraceSentenx = trace + if err != nil { + return callResult, errors.Wrapf(err, "on iterate to search products on sentenx") + } + for _, prod := range searchResult { + productRetailerIDS = append(productRetailerIDS, prod["product_retailer_id"]) + } + } + + callResult.ProductRetailerIDS = productRetailerIDS + + return callResult, nil +} + +func GetProductListFromWeniGPT(rtConfig *runtime.Config, content string) ([]string, *httpx.Trace, error) { + httpClient, httpRetries, _ := goflow.HTTP(rtConfig) + weniGPTClient := wenigpt.NewClient(httpClient, httpRetries, rtConfig.WenigptBaseURL, rtConfig.WenigptAuthToken, rtConfig.WenigptCookie) + + prompt := fmt.Sprintf(`Give me an unformatted JSON list containing strings with the name of each product taken from the user prompt. Never repeat the same product. Always return a valid json using this pattern: {\"products\": []} Request: %s. Response:`, content) + + dr := wenigpt.NewWenigptRequest( + prompt, + 0, + 0.0, + 0.0, + true, + wenigpt.DefaultStopSequences, + ) + + response, trace, err := weniGPTClient.WeniGPTRequest(dr) + if err != nil { + return nil, trace, errors.Wrapf(err, "error on wenigpt call fot list products") + } + + productsJson := response.Output.Text[0] + + var products map[string][]string + err = json.Unmarshal([]byte(productsJson), &products) + if err != nil { + return nil, trace, errors.Wrapf(err, "error on unmarshalling product list") + } + return products["products"], trace, nil +} + +func GetProductListFromSentenX(productSearch string, catalogID string, threshold float64, rtConfig *runtime.Config) ([]map[string]string, *httpx.Trace, error) { + client := sentenx.NewClient(http.DefaultClient, nil, rtConfig.SentenxBaseURL) + + searchParams := sentenx.NewSearchRequest(productSearch, catalogID, threshold) + + searchResponse, trace, err := client.SearchProducts(searchParams) + if err != nil { + return nil, trace, err + } + + if len(searchResponse.Products) < 1 { + return nil, trace, errors.New("no products found on sentenx") + } + + pmap := make(map[string]struct{}) + for _, p := range searchResponse.Products { + pmap[p.ProductRetailerID] = struct{}{} + } + + result := []map[string]string{} + for k := range pmap { + mapElement := map[string]string{"product_retailer_id": k} + result = append(result, mapElement) + } + + return result, trace, nil +} + +func GetProductListFromChatGPT(ctx context.Context, rtConfig *runtime.Config, content string) ([]string, *httpx.Trace, error) { + httpClient, httpRetries, _ := goflow.HTTP(rtConfig) + chatGPTClient := chatgpt.NewClient(httpClient, httpRetries, rtConfig.ChatgptBaseURL, rtConfig.ChatgptKey) + + prompt1 := chatgpt.ChatCompletionMessage{ + Role: chatgpt.ChatMessageRoleSystem, + Content: "Give me an unformatted JSON list containing strings with the name of each product taken from the user prompt.", + } + prompt2 := chatgpt.ChatCompletionMessage{ + Role: chatgpt.ChatMessageRoleSystem, + Content: "Never repeat the same product.", + } + prompt3 := chatgpt.ChatCompletionMessage{ + Role: chatgpt.ChatMessageRoleSystem, + Content: "Always use this pattern: {\"products\": []}", + } + question := chatgpt.ChatCompletionMessage{ + Role: chatgpt.ChatMessageRoleUser, + Content: content, + } + completionRequest := chatgpt.NewChatCompletionRequest([]chatgpt.ChatCompletionMessage{prompt1, prompt2, prompt3, question}) + response, trace, err := chatGPTClient.CreateChatCompletion(completionRequest) + if err != nil { + return nil, trace, errors.Wrapf(err, "error on chatgpt call for list products") + } + + productsJson := response.Choices[0].Message.Content + + var products map[string][]string + err = json.Unmarshal([]byte(productsJson), &products) + if err != nil { + return nil, trace, errors.Wrapf(err, "error on unmarshalling product list") + } + return products["products"], trace, nil +} diff --git a/services/external/weni/service_test.go b/services/external/weni/service_test.go new file mode 100644 index 000000000..d0eb38993 --- /dev/null +++ b/services/external/weni/service_test.go @@ -0,0 +1,94 @@ +package catalogs_test + +import ( + "net/http" + "testing" + + "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/goflow/assets" + "github.com/nyaruka/goflow/assets/static" + "github.com/nyaruka/goflow/envs" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/test" + catalogs "github.com/nyaruka/mailroom/services/external/weni" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService(t *testing.T) { + _, rt, _, _ := testsuite.Get() + + defer dates.SetNowSource(dates.DefaultNowSource) + session, _, err := test.CreateTestSession("", envs.RedactionPolicyNone) + require.NoError(t, err) + + defer uuids.SetGenerator(uuids.DefaultGenerator) + defer httpx.SetRequestor(httpx.DefaultRequestor) + + uuids.SetGenerator(uuids.NewSeededGenerator(12345)) + + httpx.SetRequestor(httpx.NewMockRequestor(map[string][]httpx.MockResponse{ + "https://api.openai.com/v1/chat/completions": { + httpx.NewMockResponse(200, nil, `{ + "id": "chatcmpl-7IfBIQsTVKbwOiHPgcrpthaCn7K1t", + "object": "chat.completion", + "created":1684682560, + "model":"gpt-3.5-turbo-0301", + "usage":{ + "prompt_tokens":26, + "completion_tokens":8, + "total_tokens":34 + }, + "choices":[ + { + "message":{ + "role":"assistant", + "content":"{\"products\": [\"banana\"]}" + }, + "finish_reason":"stop", + "index":0 + } + ] + }`), + }, + "https://sentenx.weni.ai/products/search": { + httpx.NewMockResponse(200, nil, `{ + "products": [ + { + "facebook_id": "1234567891", + "title": "banana prata 1kg", + "org_id": "1", + "channel_id": "10000", + "catalog_id": "123456789", + "product_retailer_id": "p1" + } + ] + }`), + }, + })) + + catalogService := flows.NewMsgCatalog(static.NewMsgCatalog(assets.MsgCatalogUUID(testdata.Org1.UUID), "msg_catalog", "msg_catalog", assets.ChannelUUID(uuids.New()))) + + svc, err := catalogs.NewService(rt.Config, http.DefaultClient, nil, catalogService, map[string]string{}) + + assert.NoError(t, err) + + logger := &flows.HTTPLogger{} + + params := assets.MsgCatalogParam{ + ProductSearch: "", + ChannelUUID: uuids.UUID(testdata.TwilioChannel.UUID), + } + call, err := svc.Call(session, params, logger.Log) + assert.NoError(t, err) + assert.NotNil(t, call) + + assert.Equal(t, "p1", call.ProductRetailerIDS[0]) + assert.NotNil(t, call.TraceWeniGPT) + assert.NotNil(t, call.TraceSentenx) + +} diff --git a/services/external/weni/wenigpt/client.go b/services/external/weni/wenigpt/client.go new file mode 100644 index 000000000..da61087f3 --- /dev/null +++ b/services/external/weni/wenigpt/client.go @@ -0,0 +1,141 @@ +package wenigpt + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/gocommon/jsonx" + "github.com/pkg/errors" +) + +const BaseURL = "https://api.runpod.ai/v2/y4dkssg660i2vp/runsync" + +var ( + defaultMaxNewTokens = int64(1000) + defaultTopP = float64(0.1) + defaultTemperature = float64(0.1) + DefaultStopSequences = []string{"Request", "Response"} +) + +type Input struct { + Prompt string `json:"prompt"` + SamplingParams SamplingParams `json:"sampling_params"` +} + +type Output struct { + Text []string `json:"text"` +} + +type SamplingParams struct { + MaxNewTokens int64 `json:"max_new_tokens"` + TopP float64 `json:"top_p"` + Temperature float64 `json:"temperature"` + DoSample bool `json:"do_sample"` + StopSequences []string `json:"stop_sequences,omitempty"` +} + +type WeniGPTRequest struct { + Input Input `json:"input"` +} + +type WeniGPTResponse struct { + DelayTime int64 `json:"delayTime"` + ExecutionTime int64 `json:"executionTime"` + ID string `json:"id"` + Output Output `json:"output"` + Status string `json:"status"` +} + +type ErrorResponse struct { + Error string `json:"error"` +} + +type WeniGPTStatus string + +const ( + STATUS_COMPLETED = WeniGPTStatus("COMPLETED") + STATUS_IN_PROGRESS = WeniGPTStatus("IN_PROGRESS") +) + +func NewWenigptRequest(prompt string, maxNewTokens int64, topP float64, temperature float64, doSample bool, stopSequences []string) *WeniGPTRequest { + if maxNewTokens <= 0 { + maxNewTokens = defaultMaxNewTokens + } + if topP <= 0.0 { + topP = defaultTopP + } + if temperature <= 0.0 { + temperature = defaultTemperature + } + + return &WeniGPTRequest{ + Input: Input{ + Prompt: prompt, + SamplingParams: SamplingParams{ + MaxNewTokens: maxNewTokens, + TopP: topP, + Temperature: temperature, + DoSample: doSample, + StopSequences: stopSequences, + }, + }, + } +} + +type Client struct { + httpClient *http.Client + httpRetries *httpx.RetryConfig + baseURL string + authorization string + cookie string +} + +func NewClient(httpClient *http.Client, httpRetries *httpx.RetryConfig, baseURL, authorization, cookie string) *Client { + return &Client{httpClient, httpRetries, baseURL, authorization, cookie} +} + +func (c *Client) Request(method, url string, body, response interface{}) (*httpx.Trace, error) { + b, err := json.Marshal(body) + if err != nil { + return nil, err + } + data := strings.NewReader(string(b)) + req, err := httpx.NewRequest(method, url, data, nil) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", "Bearer "+c.authorization) + req.Header.Add("Cookie", c.cookie) + + trace, err := httpx.DoTrace(c.httpClient, req, c.httpRetries, nil, -1) + if err != nil { + return trace, err + } + + if trace.Response.StatusCode >= 400 { + response := &ErrorResponse{} + jsonx.Unmarshal(trace.ResponseBody, response) + return trace, errors.New(response.Error) + } + + if response != nil { + err := json.Unmarshal(trace.ResponseBody, response) + return trace, errors.Wrap(err, "couldn't parse response body") + } + + return trace, nil +} + +func (c *Client) WeniGPTRequest(data *WeniGPTRequest) (*WeniGPTResponse, *httpx.Trace, error) { + requestURL := c.baseURL + response := &WeniGPTResponse{} + + trace, err := c.Request("POST", requestURL, data, response) + if err != nil { + return nil, trace, err + } + return response, trace, nil +} diff --git a/services/external/weni/wenigpt/client_test.go b/services/external/weni/wenigpt/client_test.go new file mode 100644 index 000000000..8e13d1ea3 --- /dev/null +++ b/services/external/weni/wenigpt/client_test.go @@ -0,0 +1,80 @@ +package wenigpt_test + +import ( + "net/http" + "testing" + + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/mailroom/services/external/weni/wenigpt" + "github.com/stretchr/testify/assert" +) + +const ( + baseURL = "https://wenigpt.weni.ai" + authorization = "098e5a87-7221-45ba-9f06-98d066fed8e5" + cookie = "4f01f95e-fe65-4484-92d6-d7bff41fa06e" +) + +func TestRequest(t *testing.T) { + client := wenigpt.NewClient(http.DefaultClient, nil, baseURL, authorization, cookie) + + _, err := client.Request("POST", "", func() {}, nil) + assert.Error(t, err) + + _, err = client.Request("{[:INVALID:]}", "", nil, nil) + assert.Error(t, err) + + defer httpx.SetRequestor(httpx.DefaultRequestor) + httpx.SetRequestor(httpx.NewMockRequestor(map[string][]httpx.MockResponse{ + baseURL: { + httpx.NewMockResponse(400, nil, `{ + "error": "dummy error message" + }`), + httpx.NewMockResponse(200, nil, `{}`), + httpx.NewMockResponse(400, nil, `{ + "error": "dummy error message" + }`), + }, + })) + + _, err = client.Request("POST", baseURL, nil, nil) + assert.Error(t, err) + + _, err = client.Request("POST", baseURL, nil, nil) + assert.Nil(t, err) + + response := new(interface{}) + _, err = client.Request("POST", baseURL, nil, response) + assert.Error(t, err) +} + +func TestWeniGPTRequest(t *testing.T) { + defer httpx.SetRequestor(httpx.DefaultRequestor) + httpx.SetRequestor(httpx.NewMockRequestor(map[string][]httpx.MockResponse{ + baseURL: { + httpx.NewMockResponse(400, nil, `{ + "error": "dummy error message" + }`), + httpx.NewMockResponse(200, nil, `{ + "delayTime": 2, + "executionTime": 2, + "id": "66b6a02c-b6e5-4e94-be8b-c631875b24d1", + "status": "COMPLETED", + "output": { + "text": ["banana"] + } + }`), + }, + })) + + client := wenigpt.NewClient(http.DefaultClient, nil, baseURL, authorization, cookie) + + data := wenigpt.NewWenigptRequest("Request: say wenigpt response output text. Response", 0, 0.0, 0.0, true, wenigpt.DefaultStopSequences) + + _, _, err := client.WeniGPTRequest(nil) + assert.EqualError(t, err, "dummy error message") + + wmsg, _, err := client.WeniGPTRequest(data) + assert.NoError(t, err) + assert.Equal(t, "banana", wmsg.Output.Text[0]) +}