Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: finish task #13

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions XinxinAkuma/AI/AI.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package AI

import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"github.com/joho/godotenv"
"io"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"
)

const (
xunfeiAIAPIUrl = "wss://spark-api.xf-yun.com/v3.5/chat"
)

// GenerateSum 通过WebSocket与AI模型交互以生成答案
func GenerateSum(question string, answers []string) (string, error) {
err := godotenv.Load("D:/ermian/Akuma/secret.env")
if err != nil {
log.Fatal("Error loading .env file")
}

apiKey := os.Getenv("API_KEY")
apiSecret := os.Getenv("API_SECRET")
appId := os.Getenv("APP_ID")

d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
// 握手并建立websocket连接
conn, resp, err := d.Dial(assembleAuthUrl1(xunfeiAIAPIUrl, apiKey, apiSecret), nil)
if err != nil {
return "", fmt.Errorf("连接失败: %s, %v", readResp(resp), err)
}
defer func(conn *websocket.Conn) {
err := conn.Close()
if err != nil {

}
Comment on lines +43 to +47
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Handle Errors in Deferred Function

The deferred function closing the WebSocket connection does not handle potential errors from conn.Close(). Ignoring errors can make debugging difficult if issues arise during connection closure.

Apply this diff to handle any errors when closing the connection:

 defer func(conn *websocket.Conn) {
 	err := conn.Close()
 	if err != nil {
-
+		fmt.Printf("Error closing WebSocket connection: %v\n", err)
 	}
 }(conn) // Ensure the connection is closed when the function ends

This change will log any errors encountered when attempting to close the connection.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
defer func(conn *websocket.Conn) {
err := conn.Close()
if err != nil {
}
defer func(conn *websocket.Conn) {
err := conn.Close()
if err != nil {
fmt.Printf("Error closing WebSocket connection: %v\n", err)
}
}(conn) // Ensure the connection is closed when the function ends

}(conn) // 确保在函数结束时关闭连接

// 将所有的答案用 | 符号连接起来
joinedAnswers := strings.Join(answers, "| ")

// 构造最终的提示词
prompt := fmt.Sprintf("我会给你一个问题和一组用 | 符号分隔的答案,帮我总结一个完整的回答,不要带有自己的评论和分析。 问题: %s\n答案: %s", question, joinedAnswers)
Comment on lines +50 to +54
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve Prompt Clarity for the AI Model

The prompt constructed for the AI model could be clearer to ensure it generates the desired response. Providing precise instructions can improve the quality of the AI's output.

Consider rephrasing the prompt for better clarity:

 // 构造最终的提示词
-prompt := fmt.Sprintf("我会给你一个问题和一组用 | 符号分隔的答案,帮我总结一个完整的回答,不要带有自己的评论和分析。 问题: %s\n答案: %s", question, joinedAnswers)
+prompt := fmt.Sprintf("请根据以下问题和提供的多个答案,总结成一个完整的回答,不要添加任何评论或分析。\n问题:%s\n答案:%s", question, joinedAnswers)

This rephrased prompt provides clear instructions, which can help the AI model generate the expected summary.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// 将所有的答案用 | 符号连接起来
joinedAnswers := strings.Join(answers, "| ")
// 构造最终的提示词
prompt := fmt.Sprintf("我会给你一个问题和一组用 | 符号分隔的答案,帮我总结一个完整的回答,不要带有自己的评论和分析。 问题: %s\n答案: %s", question, joinedAnswers)
// 将所有的答案用 | 符号连接起来
joinedAnswers := strings.Join(answers, "| ")
// 构造最终的提示词
prompt := fmt.Sprintf("请根据以下问题和提供的多个答案,总结成一个完整的回答,不要添加任何评论或分析。\n问题:%s\n答案%s", question, joinedAnswers)

data := genParams1(appId, prompt)

// 发送数据
if err := conn.WriteJSON(data); err != nil {
return "", fmt.Errorf("发送数据失败: %v", err)
}

var answer string

// 获取返回的数据
for {
_, msg, err := conn.ReadMessage()
if err != nil {
return "", fmt.Errorf("读取消息失败: %v", err)
}

var data map[string]interface{}
if err := json.Unmarshal(msg, &data); err != nil {
return "", fmt.Errorf("解析JSON失败: %v", err)
}
// 解析数据
payload, ok := data["payload"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("无效的payload格式")
}
choices, ok := payload["choices"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("无效的choices格式")
}
header, ok := data["header"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("无效的header格式")
}
code, ok := header["code"].(float64)
if !ok || code != 0 {
return "", fmt.Errorf("错误的响应代码: %v", data["payload"])
}

status, ok := choices["status"].(float64)
if !ok {
return "", fmt.Errorf("无效的status格式")
}
text, ok := choices["text"].([]interface{})
if !ok {
return "", fmt.Errorf("无效的text格式")
}
content, ok := text[0].(map[string]interface{})["content"].(string)
if !ok {
return "", fmt.Errorf("无效的content格式")
}

if status != 2 {
answer += content
} else {
answer += content
usage, ok := payload["usage"].(map[string]interface{})
if ok {
temp, ok := usage["text"].(map[string]interface{})
if ok {
totalTokens, ok := temp["total_tokens"].(float64)
if ok {
fmt.Println("total_tokens:", totalTokens)
}
}
}
break
}
Comment on lines +71 to +121
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Simplify JSON Parsing with Structs

Manually parsing JSON using nested maps and type assertions can be error-prone and hard to maintain. Defining Go structs that mirror the JSON response structure can simplify parsing and improve code readability.

Define response structs and update the parsing logic:

type AIResponse struct {
	Header struct {
		Code float64 `json:"code"`
	} `json:"header"`
	Payload struct {
		Choices struct {
			Status float64 `json:"status"`
			Text   []struct {
				Content string `json:"content"`
			} `json:"text"`
		} `json:"choices"`
		Usage struct {
			Text struct {
				TotalTokens float64 `json:"total_tokens"`
			} `json:"text"`
		} `json:"usage"`
	} `json:"payload"`
}

Update the parsing section:

 var answer string

 // 获取返回的数据
 for {
 	_, msg, err := conn.ReadMessage()
 	if err != nil {
 		return "", fmt.Errorf("读取消息失败: %v", err)
 	}

-	var data map[string]interface{}
-	if err := json.Unmarshal(msg, &data); err != nil {
+	var aiResp AIResponse
+	if err := json.Unmarshal(msg, &aiResp); err != nil {
 		return "", fmt.Errorf("解析JSON失败: %v", err)
 	}

-	// Existing parsing logic with multiple type assertions...
+	// Check for errors in the response
+	if aiResp.Header.Code != 0 {
+		return "", fmt.Errorf("错误的响应代码: %v", aiResp.Header.Code)
+	}

+	// Append the content to the answer
+	for _, text := range aiResp.Payload.Choices.Text {
+		answer += text.Content
+	}

+	// Check if the response is complete
+	if aiResp.Payload.Choices.Status == 2 {
+		if aiResp.Payload.Usage.Text.TotalTokens > 0 {
+			fmt.Println("total_tokens:", aiResp.Payload.Usage.Text.TotalTokens)
+		}
+		break
+	}
 }

This refactoring enhances code clarity and reduces the risk of runtime errors due to incorrect type assertions.

}

// 输出返回结果
return answer, nil
}

// 生成参数
func genParams1(appid, question string) map[string]interface{} { // 根据实际情况修改返回的数据结构和字段名

messages := []Message{
{Role: "user", Content: question},
}

data := map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
"header": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
"app_id": appid, // 根据实际情况修改返回的数据结构和字段名
},
"parameter": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
"chat": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
"domain": "general", // 根据实际情况修改返回的数据结构和字段名
"temperature": float64(0.8), // 根据实际情况修改返回的数据结构和字段名
"top_k": int64(6), // 根据实际情况修改返回的数据结构和字段名
"max_tokens": int64(2048), // 根据实际情况修改返回的数据结构和字段名
"auditing": "default", // 根据实际情况修改返回的数据结构和字段名
},
},
"payload": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
"message": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
"text": messages, // 根据实际情况修改返回的数据结构和字段名
},
},
}
return data // 根据实际情况修改返回的数据结构和字段名
}

// 创建鉴权url apikey 即 hmac username
func assembleAuthUrl1(hosturl string, apiKey, apiSecret string) string {
ul, err := url.Parse(hosturl)
if err != nil {
fmt.Println(err)
}
//签名时间
date := time.Now().UTC().Format(time.RFC1123)
//date = "Tue, 28 May 2019 09:10:42 MST"
//参与签名的字段 host ,date, request-line
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
//拼接签名字符串
sgin := strings.Join(signString, "\n")
// fmt.Println(sgin)
//签名结果
sha := HmacWithShaTobase64("hmac-sha256", sgin, apiSecret)
// fmt.Println(sha)
//构建请求参数 此时不需要urlencoding
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
//将请求参数使用base64编码
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))

v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
//将编码后的字符串url encode后添加到url后面
callurl := hosturl + "?" + v.Encode()
return callurl
}

func HmacWithShaTobase64(algorithm, data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
Comment on lines +190 to +193
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Handle Potential Errors in HMAC Calculation

While it's unlikely, the mac.Write() function can return an error. Handling this error ensures that your code is robust and future-proof.

Modify the function to handle the error:

 func HmacWithShaTobase64(algorithm, data, key string) string {
 	mac := hmac.New(sha256.New, []byte(key))
-	mac.Write([]byte(data))
+	if _, err := mac.Write([]byte(data)); err != nil {
+		fmt.Printf("Error writing data to HMAC: %v\n", err)
+		return ""
+	}
 	encodeData := mac.Sum(nil)
 	return base64.StdEncoding.EncodeToString(encodeData)
 }

This addition will log any errors during HMAC calculation and prevent unexpected crashes.

Committable suggestion was skipped due to low confidence.

}

func readResp(resp *http.Response) string {
if resp == nil {
return ""
}
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
Comment on lines +202 to +203
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid Using panic for Error Handling

Using panic in non-critical sections can cause the entire application to crash unexpectedly. It's better to handle errors gracefully.

Apply this diff to return an empty response and log the error:

 func readResp(resp *http.Response) string {
 	if resp == nil {
 		return ""
 	}
 	b, err := io.ReadAll(resp.Body)
 	if err != nil {
-		panic(err)
+		fmt.Printf("Error reading response body: %v\n", err)
+		return ""
 	}
 	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
 }

This change ensures that the application can handle the error without crashing.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
panic(err)
}
fmt.Printf("Error reading response body: %v\n", err)
return ""
}

return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
93 changes: 93 additions & 0 deletions XinxinAkuma/auth/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package auth

import (
"errors"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v4"
"net/http"
"os"
"strings"
"time"
)

// 秘钥
var jwtSecret = []byte(os.Getenv("JWT_KEY"))

// 定义 JWT 的声明结构
type Claims struct {
UserID uint `json:"user_id"`
UserName string `json:"name"`
jwt.RegisteredClaims
}

// 生成JWT Token
func GenerateToken(userID uint, userName string) (string, error) {
// 定义 Token 的声明,包含用户信息和到期时间
claims := &Claims{
UserID: userID,
UserName: userName,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 72)), // Token 有效期 72 小时
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider making the token expiration duration configurable.

Currently, the token expiration is hardcoded to 72 hours. Making it configurable enhances flexibility and allows you to adjust expiration without changing code.

Apply this diff to make the expiration time configurable via an environment variable:

+import (
+	"os"
+	"strconv"
+)

 // Generate JWT Token
 func GenerateToken(userID uint, userName string) (string, error) {
 	// Define token claims, including user information and expiration time
+	expirationHours := 72 // default expiration
+	if envExpiry := os.Getenv("TOKEN_EXPIRATION_HOURS"); envExpiry != "" {
+		if hours, err := strconv.Atoi(envExpiry); err == nil {
+			expirationHours = hours
+		}
+	}

 	claims := &Claims{
 		UserID:   userID,
 		UserName: userName,
 		RegisteredClaims: jwt.RegisteredClaims{
-			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 72)), // Token valid for 72 hours
+			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(expirationHours) * time.Hour)),
 		},
 	}

Don't forget to handle any potential errors and ensure that the environment variable TOKEN_EXPIRATION_HOURS is set appropriately.

Committable suggestion was skipped due to low confidence.

},
}

// 创建带有声明的 Token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

// 签署 Token 并返回
return token.SignedString(jwtSecret)
}

// 验证JWT的中间件
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取请求头中的 Authorization 字段
tokenString := strings.TrimSpace(c.GetHeader("Authorization"))
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权,请登录"})
c.Abort()
return
}
Comment on lines +47 to +50
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Refactor repetitive error handling into a helper function.

Multiple blocks in your middleware handle errors in a similar way. Refactoring them into a helper function reduces code duplication and enhances readability.

Create a helper function for unauthorized responses:

 func AuthMiddleware() gin.HandlerFunc {
+	// Helper function for unauthorized responses
+	unauthorized := func(c *gin.Context, message string) {
+		c.JSON(http.StatusUnauthorized, gin.H{"error": message})
+		c.Abort()
+	}

 	return func(c *gin.Context) {
 		// Get the Authorization header
 		tokenString := strings.TrimSpace(c.GetHeader("Authorization"))
 		if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
-			c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权,请登录"})
-			c.Abort()
-			return
+			unauthorized(c, "未授权,请登录")
+			return
 		}

 		// ... existing code ...

 		// Check for token parsing errors or invalid token
 		if err != nil {
 			if ve, ok := err.(*jwt.ValidationError); ok {
 				if ve.Errors&jwt.ValidationErrorMalformed != 0 {
-					c.JSON(http.StatusUnauthorized, gin.H{"error": "Token格式错误"})
-					c.Abort()
-					return
+					unauthorized(c, "Token格式错误")
+					return
 				} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
-					c.JSON(http.StatusUnauthorized, gin.H{"error": "Token已过期"})
-					c.Abort()
-					return
+					unauthorized(c, "Token已过期")
+					return
 				} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
-					c.JSON(http.StatusUnauthorized, gin.H{"error": "Token尚未生效"})
-					c.Abort()
-					return
+					unauthorized(c, "Token尚未生效")
+					return
 				} else {
-					c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"})
-					c.Abort()
-					return
+					unauthorized(c, "无效的Token")
+					return
 				}
-				c.Abort()
-				return
 			}
 		}

 		// ... existing code ...

 		} else {
-			c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"})
-			c.Abort()
+			unauthorized(c, "无效的Token")
 			return
 		}

This simplifies your middleware and makes it easier to manage error responses.

Also applies to: 64-77, 85-87


// 移除 Bearer 前缀
tokenString = strings.TrimSpace(strings.TrimPrefix(tokenString, "Bearer "))

// 解析 Token
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Handle unexpected errors during token parsing.

Currently, if an error occurs that's not a jwt.ValidationError, it will be ignored. It's important to handle all possible errors to ensure proper error reporting.

Modify the error handling to catch unexpected errors:

 			// Check for token parsing errors or invalid token
 			if err != nil {
 				if ve, ok := err.(*jwt.ValidationError); ok {
 					// ... existing validation error handling ...
+				} else {
+					unauthorized(c, "Token解析错误")
+					return
 				}
+			} else {
+				unauthorized(c, "无效的Token")
+				return
 			}

This ensures that any unexpected parsing errors are also communicated to the client.

Committable suggestion was skipped due to low confidence.

if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("无效的签名方法")
}
return jwtSecret, nil
})
Comment on lines +57 to +61
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Explicitly check the signing algorithm to prevent security risks.

Using type assertion to check the signing method may not be sufficient. It's more secure to compare the algorithm explicitly to avoid algorithm substitution attacks.

Apply this diff to enhance the security of signing method validation:

 			// Parse Token
 			token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
-				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+				if token.Method.Alg() != jwt.SigningMethodHS256.Alg() {
 					return nil, errors.New("无效的签名方法")
 				}
 				return jwtSecret, nil
 			})

This ensures that only tokens signed with HS256 are accepted.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("无效的签名方法")
}
return jwtSecret, nil
})
if token.Method.Alg() != jwt.SigningMethodHS256.Alg() {
return nil, errors.New("无效的签名方法")
}
return jwtSecret, nil
})


// 检查 Token 解析是否出错或者无效
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token格式错误"})
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token已过期"})
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token尚未生效"})
} else {
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"})
}
c.Abort()
return
}
}

// 检查 Token 是否有效
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
// 将用户信息保存在上下文中
c.Set("user_id", claims.UserID)
c.Set("user_name", claims.UserName)
} else {
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"})
c.Abort()
return
}

c.Next() // 继续执行下一个处理器
}
}
52 changes: 52 additions & 0 deletions XinxinAkuma/database1/userdatabase.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package database1

import (
"fmt"
"github.com/joho/godotenv"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"log"
"os"
)

var DB *gorm.DB
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider encapsulating the DB variable.

While it's common to have a package-level database variable, exposing it directly can lead to issues with encapsulation and make it harder to manage database access across the application. Consider making DB unexported (lowercase db) and providing methods to interact with it instead.


// InitDB 初始化数据库连接
func InitDB() {
er := godotenv.Load("D:/ermian/Akuma/secret.env")
if er != nil {
log.Fatal("Error loading .env file")
}

dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
os.Getenv("DB_USER"),
os.Getenv("DB_PASSWORD"),
os.Getenv("DB_HOST"),
os.Getenv("DB_PORT"),
os.Getenv("DB_1"))
var err error
DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
log.Fatal("无法连接到数据库:", err)
}
log.Println("数据库连接成功")

// 测试数据库连接
sqlDB, err := DB.DB()
if err != nil {
log.Fatal("获取数据库实例失败:", err)
}

// Ping 数据库
if err := sqlDB.Ping(); err != nil {
log.Fatal("数据库连接失败:", err)
}
log.Println("数据库连接测试成功")
}

func AutoMigrate(models ...interface{}) {
err := DB.AutoMigrate(models...)
if err != nil {
log.Fatal("自动迁移失败:", err)
}
}
Comment on lines +47 to +52
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance the AutoMigrate function.

  1. Return errors instead of using log.Fatal:
    Similar to the InitDB function, consider returning errors instead of terminating the program. This allows the caller to decide how to handle migration failures.

    func AutoMigrate(models ...interface{}) error {
        if err := DB.AutoMigrate(models...); err != nil {
            return fmt.Errorf("auto migration failed: %w", err)
        }
        return nil
    }
  2. Consider adding options for migration customization:
    You might want to add options to customize the migration behavior, such as allowing to disable foreign key constraint checks during migration.

    func AutoMigrate(models ...interface{}) error {
        if err := DB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(models...); err != nil {
            return fmt.Errorf("auto migration failed: %w", err)
        }
        return nil
    }

Loading