Skip to content

Commit

Permalink
✨ add proxy option
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaomaoniu committed Jan 28, 2024
1 parent b664411 commit 0071cc1
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 479 deletions.
63 changes: 22 additions & 41 deletions nonebot_plugin_gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import os
import aiohttp
import google.generativeai as genai

from io import BytesIO
from PIL import Image as PILImage
from typing import Union
from pathlib import Path
from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.adapters import Message, Event, Bot
from nonebot import require, get_driver, on_command
from nonebot.params import CommandArg, ArgPlainText
from google.generativeai.generative_models import ChatSession
from nonebot.plugin import PluginMetadata, inherit_supported_adapters

from .config import Config
from .gemini import Gemini, GeminiChatSession

require("nonebot_plugin_alconna")
require("nonebot_plugin_htmlrender")
Expand Down Expand Up @@ -42,40 +42,27 @@
raise ValueError("GOOGLE_API_KEY 未配置, nonebot-plugin-gemini 无法运行")


# Single shared REST client for the plugin; the proxy is optional
# (None means a direct connection).
gemini = Gemini(GOOGLE_API_KEY, plugin_config.proxy)


async def to_markdown(text: str) -> bytes:
    """Render *text* as markdown and return it rasterized as PNG bytes (800px wide)."""
    # "•" bullets are rewritten as markdown list markers before rendering.
    normalized = text.replace("•", " *")
    return await md_to_pic(normalized, width=800)


async def to_image_data(image: Image) -> Union[BytesIO, bytes]:
    """Extract raw image data from a UniMessage Image segment.

    Tries the in-memory payload first, then a local file path, then a
    remote URL download.

    Raises:
        ValueError: if the segment carries no raw data, path, or URL.
    """
    if image.raw is not None:
        # Already raw data (BytesIO or bytes) — hand it back unchanged.
        return image.raw

    if image.path is not None:
        return Path(image.path).read_bytes()

    if image.url is not None:
        async with aiohttp.ClientSession() as session:
            async with session.get(image.url) as resp:
                return await resp.read()

    raise ValueError("无法获取图片数据")


chat = on_command("gemini", priority=10, block=True)
Expand All @@ -87,34 +74,30 @@ async def _(event: Event, bot: Bot, message: Message = CommandArg()):
uni_message = await UniMessage.generate(message=message, event=event, bot=bot)

msg = []
model = "gemini-pro"

for seg in uni_message:
if isinstance(seg, Text):
msg.append(seg.text)

elif isinstance(seg, Image):
model = "gemini-pro-vision"
msg.append(await to_pil_image(seg))
msg.append(await to_image_data(seg))

if not msg:
await chat.finish("未获取到有效输入,输入应为文本或图片")

try:
resp = await models[model].generate_content_async(msg)
resp = await gemini.generate(msg)
except Exception as e:
await chat.finish(f"{type(e).__name__}: {e}")

try:
result = resp.text
except ValueError:
result = "\n---\n".join(
[part.text for part in resp.candidates[0].content.parts]
)
result = resp["candidates"][0]["content"]["parts"][0]["text"]
except KeyError:
result = "未获取到有效回复"

await chat.finish(
await UniMessage(Image(raw=await to_markdown(result))).export()
if len(result) > 500
if len(result) > plugin_config.image_render_length
else result.strip()
)

Expand All @@ -126,30 +109,28 @@ async def start_conversation(
if args.extract_plain_text() != "":
matcher.set_arg(key="msg", message=args)

state["gemini_chat_session"] = models["gemini-pro"].start_chat(history=[])
state["gemini_chat_session"] = GeminiChatSession(GOOGLE_API_KEY, plugin_config.proxy)


@conversation.got("msg", prompt="对话开始")
async def got_message(state: T_State, msg: str = ArgPlainText()):
    """One turn of the multi-turn conversation; rejects to wait for the next turn."""
    # Any of these keywords ends the conversation.
    if msg in ["结束", "结束对话", "结束会话", "stop", "quit"]:
        await conversation.finish("对话结束")

    chat_session: GeminiChatSession = state["gemini_chat_session"]

    try:
        resp = await chat_session.send_message(msg)
    except Exception as e:
        await conversation.finish(f"发生意外错误,对话已结束\n{type(e).__name__}: {e}")

    try:
        result = resp["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError):
        # The API may return no candidates/parts (e.g. safety-blocked reply);
        # an empty list raises IndexError, a missing key raises KeyError.
        result = "未获取到有效回复"

    await conversation.reject(
        # Long replies are rendered to an image to stay readable.
        await UniMessage(Image(raw=await to_markdown(result))).export()
        if len(result) > plugin_config.image_render_length
        else result.strip()
    )
2 changes: 2 additions & 0 deletions nonebot_plugin_gemini/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@

class Config(BaseModel):
    # Google AI Studio API key; the plugin refuses to start without one.
    google_api_key: Optional[str] = None
    # Optional proxy URL forwarded to aiohttp for Gemini API requests.
    proxy: Optional[str] = None
    # Replies longer than this many characters are rendered as an image.
    # Must be an int (not Optional): it is compared against len(result),
    # and a None here would raise TypeError at runtime.
    image_render_length: int = 500
85 changes: 85 additions & 0 deletions nonebot_plugin_gemini/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import fleep
import base64
import aiohttp
from io import BytesIO
from typing import List, Union

from .model import Response as GeminiResponse


class Gemini:
    """Lightweight async client for the Google Generative Language REST API.

    Text-only requests use the ``gemini-pro`` model; as soon as any image
    content is present the request switches to ``gemini-pro-vision``.
    """

    def __init__(self, google_api_key: str, proxy: Union[str, None] = None):
        # Proxy URL forwarded to aiohttp; None means a direct connection.
        self._proxy = proxy
        self._google_api_key = google_api_key

    async def generate(
        self,
        contents: Union[List[Union[str, bytes, BytesIO]], str] = "",
        *,
        _contents: Union[list, None] = None,
    ) -> GeminiResponse:
        """POST one generateContent request and return the parsed JSON body.

        ``contents`` is either a plain string or a list mixing text with raw
        image data (bytes / BytesIO).  ``_contents`` is a private hook used by
        GeminiChatSession to send a pre-built conversation history instead.

        Raises:
            ValueError: on an unsupported content type.
            Exception: on a non-200 HTTP response, with the API's message.
        """
        model = "gemini-pro"

        if isinstance(contents, str):
            parts = [{"text": contents}]
        elif isinstance(contents, list):
            parts = []
            for content in contents:
                if isinstance(content, str):
                    parts.append({"text": content})
                elif isinstance(content, (bytes, BytesIO)):
                    model = "gemini-pro-vision"
                    # Normalize to bytes first: BytesIO is not subscriptable,
                    # so slicing it for mime sniffing would raise TypeError.
                    raw = (
                        content.getvalue()
                        if isinstance(content, BytesIO)
                        else content
                    )
                    # Sniff the mime type from the leading magic bytes.
                    info = fleep.get(raw[:128])
                    parts.append(
                        {
                            "inline_data": {
                                "mime_type": info.mime[0],
                                "data": self._to_b64(raw),
                            }
                        }
                    )
                else:
                    raise ValueError("Unsupported content type")
        else:
            raise ValueError("Unsupported contents type")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={self._google_api_key}",
                json={
                    "contents": [
                        {
                            "parts": parts,
                        }
                    ]
                    if _contents is None
                    else _contents  # pre-built history from GeminiChatSession
                },
                proxy=self._proxy,
            ) as resp:
                if resp.status != 200:
                    raise Exception(f'Status code: {resp.status}, message: {(await resp.json())["error"]["message"]}')

                data: GeminiResponse = await resp.json()
                return data

    def _to_b64(self, content: Union[bytes, BytesIO]) -> str:
        """Base64-encode raw bytes or a BytesIO buffer to an ASCII string."""
        if isinstance(content, bytes):
            return base64.b64encode(content).decode()
        if isinstance(content, BytesIO):
            return base64.b64encode(content.getvalue()).decode()
        raise ValueError("Unsupported content type")


class GeminiChatSession(Gemini):
    """Stateful multi-turn chat on top of Gemini, keeping the full history."""

    def __init__(self, google_api_key: str, proxy: Union[str, None] = None):
        # Conversation history in API wire format:
        # [{"role": ..., "parts": [{"text": ...}]}, ...]
        self.history = []

        super().__init__(google_api_key, proxy)

    async def send_message(self, message: str) -> GeminiResponse:
        """Send one user turn and record both it and the model's reply.

        If the request fails, the just-appended user turn is rolled back so
        the history stays consistent for a retry.
        """
        self.history.append({"role": "user", "parts": [{"text": message}]})
        try:
            resp = await self.generate(_contents=self.history)
            self.history.append(resp["candidates"][0]["content"])
        except Exception:
            # Keep history consistent: drop the unanswered user turn.
            self.history.pop()
            raise
        return resp
31 changes: 31 additions & 0 deletions nonebot_plugin_gemini/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, TypedDict


# Wire-format types for the generateContent REST response, declared with
# the functional TypedDict syntax.

# One piece of content; only a text field is modeled here.
Part = TypedDict("Part", {"text": str})

# A single message: its parts plus the author role string.
Content = TypedDict("Content", {"parts": List[Part], "role": str})

# Safety classification for one harm category.
SafetyRating = TypedDict("SafetyRating", {"category": str, "probability": str})

# One generated answer candidate with its metadata.
Candidate = TypedDict(
    "Candidate",
    {
        "content": Content,
        "finishReason": str,
        "index": int,
        "safetyRatings": List[SafetyRating],
    },
)

# Safety feedback about the prompt itself.
PromptFeedback = TypedDict(
    "PromptFeedback", {"safetyRatings": List[SafetyRating]}
)

# Top-level generateContent response body.
Response = TypedDict(
    "Response",
    {"candidates": List[Candidate], "promptFeedback": PromptFeedback},
)
Loading

0 comments on commit 0071cc1

Please sign in to comment.