diff --git a/internal/bot/bot.go b/internal/bot/bot.go index 3aa9570..e0a9b1b 100644 --- a/internal/bot/bot.go +++ b/internal/bot/bot.go @@ -84,7 +84,7 @@ func (b *Bot) monitorUsers() { } defer tx.Rollback() - rows, err := tx.Query("SELECT guild_id, user_id, username, notification_channel, last_post_id, last_stream_start, mention_role, avatar_location, avatar_location_updated_at, live_image_url, posts_enabled, live_enabled FROM monitored_users") + rows, err := tx.Query("SELECT guild_id, user_id, username, notification_channel, post_notification_channel, live_notification_channel, last_post_id, last_stream_start, mention_role, avatar_location, avatar_location_updated_at, live_image_url, posts_enabled, live_enabled FROM monitored_users") if err != nil { return err } @@ -96,6 +96,8 @@ func (b *Bot) monitorUsers() { UserID string Username string NotificationChannel string + PostNotificationChannel string + LiveNotificationChannel string LastPostID string LastStreamStart int64 MentionRole string @@ -106,7 +108,7 @@ func (b *Bot) monitorUsers() { LiveEnabled bool } - err := rows.Scan(&user.GuildID, &user.UserID, &user.Username, &user.NotificationChannel, &user.LastPostID, &user.LastStreamStart, &user.MentionRole, &user.AvatarLocation, &user.AvatarLocationUpdatedAt, &user.LiveImageURL, &user.PostsEnabled, &user.LiveEnabled) + err := rows.Scan(&user.GuildID, &user.UserID, &user.Username, &user.NotificationChannel, &user.PostNotificationChannel, &user.LiveNotificationChannel, &user.LastPostID, &user.LastStreamStart, &user.MentionRole, &user.AvatarLocation, &user.AvatarLocationUpdatedAt, &user.LiveImageURL, &user.PostsEnabled, &user.LiveEnabled) if err != nil { log.Printf("Error scanning row: %v", err) continue @@ -155,7 +157,12 @@ func (b *Bot) monitorUsers() { mention = fmt.Sprintf("<@&%s>", user.MentionRole) } - _, err = b.Session.ChannelMessageSendComplex(user.NotificationChannel, &discordgo.MessageSend{ + targetChannel := user.LiveNotificationChannel + if targetChannel == "" { + targetChannel = user.NotificationChannel + } + + _, err = b.Session.ChannelMessageSendComplex(targetChannel, &discordgo.MessageSend{ Content: mention, Embed: embedMsg, }) @@ -212,7 +219,12 @@ func (b *Bot) monitorUsers() { mention = fmt.Sprintf("<@&%s>", user.MentionRole) } - _, err = b.Session.ChannelMessageSendComplex(user.NotificationChannel, &discordgo.MessageSend{ + targetChannel := user.PostNotificationChannel + if targetChannel == "" { + targetChannel = user.NotificationChannel + } + + _, err = b.Session.ChannelMessageSendComplex(targetChannel, &discordgo.MessageSend{ Content: mention, Embed: embedMsg, }) diff --git a/internal/bot/commands.go b/internal/bot/commands.go index 83e6fe2..950d70e 100644 --- a/internal/bot/commands.go +++ b/internal/bot/commands.go @@ -99,6 +99,40 @@ func (b *Bot) registerCommands() { }, }, }, + { + Name: "setchannel", + Description: "Set notification channel for posts or live notifications", + Options: []*discordgo.ApplicationCommandOption{ + { + Type: discordgo.ApplicationCommandOptionString, + Name: "username", + Description: "Fansly username", + Required: true, + }, + { + Type: discordgo.ApplicationCommandOptionString, + Name: "type", + Description: "notification type", + Required: true, + Choices: []*discordgo.ApplicationCommandOptionChoice{ + { + Name: "Posts", + Value: "posts", + }, + { + Name: "Live", + Value: "live", + }, + }, + }, + { + Type: discordgo.ApplicationCommandOptionChannel, + Name: "channel", + Description: "The notification channel", + Required: true, + }, + }, + }, } _, err := b.Session.ApplicationCommandBulkOverwrite(b.Session.State.User.ID, "", commands) diff --git a/internal/bot/handlers.go b/internal/bot/handlers.go index a4a7099..dc53aa9 100644 --- a/internal/bot/handlers.go +++ b/internal/bot/handlers.go @@ -33,6 +33,8 @@ func (b *Bot) interactionCreate(s *discordgo.Session, i *discordgo.InteractionCr b.handleSetLiveImageCommand(s, i) case "toggle": b.handleToggleCommand(s, i) + case "setchannel": + b.handleSetChannelCommand(s, i) } } @@ -100,7 +102,7 @@ func (b *Bot) handleAddCommand(s *discordgo.Session, i *discordgo.InteractionCre } } - if isFollowing { + if !isFollowing { followErr := b.APIClient.FollowAccount(accountInfo.ID) if followErr != nil { log.Printf("Note: Could not follow %s: %v", username, followErr) @@ -121,11 +123,11 @@ func (b *Bot) handleAddCommand(s *discordgo.Session, i *discordgo.InteractionCre // Store the monitored user in the database err = b.retryDbOperation(func() error { - _, err := b.DB.Exec(` - INSERT OR REPLACE INTO monitored_users - (guild_id, user_id, username, notification_channel, last_post_id, last_stream_start, mention_role, avatar_location, avatar_location_updated_at, live_image_url, posts_enabled, live_enabled) - VALUES (?, ?, ?, ?, '', 0, ?, ?, ?, ?, 1, 1) - `, i.GuildID, accountInfo.ID, username, channel.ID, mentionRole, avatarLocation, time.Now().Unix(), "") + _, err = b.DB.Exec(` + INSERT OR REPLACE INTO monitored_users + (guild_id, user_id, username, notification_channel, post_notification_channel, live_notification_channel, last_post_id, last_stream_start, mention_role, avatar_location, avatar_location_updated_at, live_image_url, posts_enabled, live_enabled) + VALUES (?, ?, ?, ?, ?, ?, '', 0, ?, ?, ?, ?, 1, 1) + `, i.GuildID, accountInfo.ID, username, channel.ID, channel.ID, channel.ID, mentionRole, avatarLocation, time.Now().Unix(), "") return err }) @@ -332,6 +334,44 @@ func (b *Bot) handleToggleCommand(s *discordgo.Session, i *discordgo.Interaction b.respondToInteraction(s, i, fmt.Sprintf("%s notifications %s for %s", notifiType, status, username)) } +func (b *Bot) handleSetChannelCommand(s *discordgo.Session, i *discordgo.InteractionCreate) { + options := i.ApplicationCommandData().Options + username := options[0].StringValue() + notifType := options[1].StringValue() + channel := options[2].ChannelValue(s) + + var columnName string + switch notifType { + case "posts": + columnName = "post_notification_channel" + case "live": + columnName = "live_notification_channel" + default: + b.respondToInteraction(s, i, "Invalid notification type") + return + } + + query := fmt.Sprintf(` + UPDATE monitored_users + SET %s = ? + WHERE guild_id = ? AND username = ? + `, columnName) + + result, err := b.DB.Exec(query, channel.ID, i.GuildID, username) + if err != nil { + b.respondToInteraction(s, i, fmt.Sprintf("Error updating channel: %v", err)) + return + } + + rowsAffected, _ := result.RowsAffected() + if rowsAffected == 0 { + b.respondToInteraction(s, i, fmt.Sprintf("User %s not found", username)) + return + } + + b.respondToInteraction(s, i, fmt.Sprintf("Successfully set %s notification channel for %s to %s", notifType, username, channel.Mention())) +} + // Add this new helper function func (b *Bot) editInteractionResponse(s *discordgo.Session, i *discordgo.InteractionCreate, content string) { _, err := s.InteractionResponseEdit(i.Interaction, &discordgo.WebhookEdit{ diff --git a/internal/database/databse.go b/internal/database/databse.go index 385535e..c63c92d 100644 --- a/internal/database/databse.go +++ b/internal/database/databse.go @@ -9,6 +9,8 @@ import ( var DB *sql.DB +const currentVersion = 2 + func Init() { var err error DB, err = sql.Open("sqlite", "bot.db") @@ -16,7 +18,116 @@ func Init() { log.Fatal(err) } - createTables() + _, err = DB.Exec(` + CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER PRIMARY KEY + ) + `) + if err != nil { + log.Fatal(err) + } + + // Get current schema version + var version int + err = DB.QueryRow("SELECT version FROM schema_version").Scan(&version) + if err != nil { + // No version found, assume fresh install + _, err = DB.Exec("INSERT INTO schema_version (version) VALUES (0)") + if err != nil { + log.Fatal(err) + } + version = 0 + } + + // Run migrations + runMigrations(version) +} + +func runMigrations(currentDBVersion int) { + migrations := []func(*sql.DB) error{ + migrateToV1, + migrateToV2, + // Add new migrations here + } + + for i, migration := range migrations { + version := i + 1 + if version <= currentDBVersion { + continue + } + + log.Printf("Running migration to version %d", version) + err := migration(DB) + if err != nil { + log.Fatalf("Migration to version %d failed: %v", version, err) + } + + _, err = DB.Exec("UPDATE schema_version SET version = ?", version) + if err != nil { + log.Fatalf("Failed to update schema version: %v", err) + } + log.Printf("Migration to version %d completed", version) + } +} + +func migrateToV1(db *sql.DB) error { + // Initial schema + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS monitored_users ( + guild_id TEXT, + user_id TEXT, + username TEXT, + notification_channel TEXT, + last_post_id TEXT, + last_stream_start INTEGER, + mention_role TEXT, + avatar_location TEXT, + avatar_location_updated_at INTEGER, + live_image_url TEXT, + posts_enabled BOOLEAN DEFAULT 1, + live_enabled BOOLEAN DEFAULT 1, + PRIMARY KEY (guild_id, user_id) + ) + `) + return err +} + +func migrateToV2(db *sql.DB) error { + // Add separate notification channels + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Add new columns + _, err = tx.Exec(` + ALTER TABLE monitored_users + ADD COLUMN post_notification_channel TEXT; + `) + if err != nil { + return err + } + + _, err = tx.Exec(` + ALTER TABLE monitored_users + ADD COLUMN live_notification_channel TEXT; + `) + if err != nil { + return err + } + + // Set default values from existing notification_channel + _, err = tx.Exec(` + UPDATE monitored_users + SET post_notification_channel = notification_channel, + live_notification_channel = notification_channel + `) + if err != nil { + return err + } + + return tx.Commit() } func Close() { @@ -30,6 +141,8 @@ func createTables() { user_id TEXT, username TEXT, notification_channel TEXT, + post_notification_channel TEXT, + live_notification_channel TEXT, last_post_id TEXT, last_stream_start INTEGER, mention_role TEXT,