Skip to content

Commit

Permalink
Merge pull request #58 from olasunkanmi-SE/test
Browse files Browse the repository at this point in the history
establish chat with LLM
  • Loading branch information
olasunkanmi-SE authored Mar 25, 2024
2 parents 0f90f30 + 6bafdf7 commit 5860273
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 48 deletions.
20 changes: 15 additions & 5 deletions api/controllers/chat-controller.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import * as express from "express";
import { ChatHandler } from "../handlers/chat.handler";
import { chatRequestSchema } from "../lib/validation-schemas";
import {
chatHistorySchema,
chatRequestSchema,
} from "../lib/validation-schemas";
import { Result } from "../lib/result";
import { generateErrorResponse } from "../utils/utils";
export class ChatController {
Expand All @@ -14,16 +17,23 @@ export class ChatController {
this.router.post(`${this.path}`, this.chat);
}

async chat(req: express.Request, res: express.Response, next: express.NextFunction) {
async chat(
req: express.Request,
res: express.Response,
next: express.NextFunction
) {
try {
const { question } = chatRequestSchema.parse(req.body);
const { question, chatHistory } = chatRequestSchema.parse(req.body);
const chatHandler = new ChatHandler();
const data = await chatHandler.handle({ question });
const history = chatHistorySchema.parse(JSON.parse(chatHistory));
const data = await chatHandler.handle({ question, chatHistory: history });
if (data) {
const result = Result.ok(data.getValue());
res.status(200).json(result);
} else {
res.status(400).json(Result.fail("Unable to create document type", 400));
res
.status(400)
.json(Result.fail("Unable to create document type", 400));
}
} catch (error) {
generateErrorResponse(error, res, next);
Expand Down
25 changes: 20 additions & 5 deletions api/handlers/chat.handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,35 @@ import { EmbeddingService } from "../services/embed.service";
import { getValue } from "../utils";
import { CHAT_PARAMS } from "./../../presentation/src/constants";

export class ChatHandler implements IRequestHandler<IChatRequestDTO, Result<IChatResponseDTO>> {
export class ChatHandler
implements IRequestHandler<IChatRequestDTO, Result<IChatResponseDTO>>
{
private readonly apiKey: string = getValue("API_KEY");
async handle({ question }: IChatRequestDTO): Promise<Result<IChatResponseDTO>> {
async handle({
question,
chatHistory,
}: IChatRequestDTO): Promise<Result<IChatResponseDTO>> {
try {
const embeddingService: EmbeddingService = new EmbeddingService(this.apiKey);
const embeddingService: EmbeddingService = new EmbeddingService(
this.apiKey
);
const { MATCH_COUNT, SIMILARITY_THRESHOLD } = CHAT_PARAMS;
const matches = await embeddingService.getQueryMatches(question, MATCH_COUNT, SIMILARITY_THRESHOLD);
const matches = await embeddingService.getQueryMatches(
question,
MATCH_COUNT,
SIMILARITY_THRESHOLD
);
// if (!matches?.length) {
// //take care of empty results here
// return "No matches for user query";
// }
const context: string = matches.map((match) => match.context).join(" ,");
const questions: string[] = [question];
const chatService: ChatService = new ChatService(this.apiKey, { context, questions });
const chatService: ChatService = new ChatService(this.apiKey, {
context,
questions,
chatHistory,
});
const response = await chatService.run();
return Result.ok(response);
} catch (error) {
Expand Down
16 changes: 12 additions & 4 deletions api/handlers/create-document-embed.handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,27 @@ import { IRequestHandler } from "../interfaces/handler";
import { HTTP_RESPONSE_CODE } from "../lib/constants";
import { Result } from "../lib/result";
import { ICreateEmbeddingRequestDTO } from "../repositories/dtos/dtos";
import { ChatService } from "../services/chat.service";
import { EmbeddingService } from "../services/embed.service";
import { getValue } from "../utils";

export class CreateDocumentEmbeddingHandler implements IRequestHandler<ICreateEmbeddingRequestDTO, Result<boolean>> {
export class CreateDocumentEmbeddingHandler
implements IRequestHandler<ICreateEmbeddingRequestDTO, Result<boolean>>
{
private readonly apiKey: string = getValue("API_KEY");
embeddingService: EmbeddingService = new EmbeddingService(this.apiKey);
async handle(request: ICreateEmbeddingRequestDTO): Promise<Result<boolean>> {
try {
const { title, documentType, domain } = request;
const result = await this.embeddingService.createDocumentsEmbeddings(title, documentType, domain);
const result = await this.embeddingService.createDocumentsEmbeddings(
title,
documentType,
domain
);
if (!result) {
throw new HttpException(HTTP_RESPONSE_CODE.BAD_REQUEST, "An error occured, could not create embeddings");
throw new HttpException(
HTTP_RESPONSE_CODE.BAD_REQUEST,
"An error occured, could not create embeddings"
);
}
return result;
} catch (error) {
Expand Down
19 changes: 12 additions & 7 deletions api/lib/validation-schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@ export const chatRequestSchema = z.object({
pageNumber: z.number(),
})
),
chatHistory: z
.object({
role: z.string(),
parts: z.object({ text: z.string() }).array(),
})
.optional()
.array(),
chatHistory: z.string(),
});

export const chatHistorySchema = z.array(
z.object({
role: z.string(),
parts: z.array(
z.object({
text: z.string(),
})
),
})
);
6 changes: 6 additions & 0 deletions api/repositories/dtos/dtos.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ export interface IChatRequestDTO {
documentId: number;
pageNumber: number;
};
chatHistory?: IHistory[];
}

export interface IHistory {
role?: string;
parts?: { text?: string }[];
}

export interface IChatResponseDTO {
Expand Down
44 changes: 29 additions & 15 deletions api/services/chat.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ import {
import { oneLine, stripIndents } from "common-tags";
import { AiModels } from "../lib/constants";
import { GenerativeAIService } from "./ai.service";
import { IChatResponseDTO } from "../repositories/dtos/dtos";
import { IChatResponseDTO, IHistory } from "../repositories/dtos/dtos";

export class ChatService extends GenerativeAIService {
initialConvo: any;
constructor(
apiKey: string,
private readonly conversation: { context: string; questions: string[] }
private readonly conversation: {
context: string;
questions: string[];
chatHistory: IHistory[];
}
) {
super(apiKey);
this.initChat();
Expand All @@ -29,36 +33,42 @@ export class ChatService extends GenerativeAIService {
parts: [
{
text: stripIndents`${oneLine`
Using the information contained in the context,
give a comprehensive answer to the question.
Respond only to the question asked, response should be concise and relevant to the question.
You should imply how the mybid project intend to solve issues relating to the question if you can find
any within the given context.
If the answer cannot be deduced from the context, do not give an answer.
context: ${conversation.context}
Give a coincise answer
Examine the given context for any problems or challenges mentioned.
Consider how the MyBid project could potentially address or solve these issues based on the context provided.
If it's possible to deduce how MyBid intends to solve the issues, provide that information. If not, respond with "I don't know".
Avoid External Sources: Do not search for information outside of the given context to formulate your response.
If you cannot find any relevent information in relating to the Question, just answer I am sorry I dont know.
Here is the context: ${conversation.context}
`}`,
},
],
},
{
role: "model",
parts: [{ text: "Great to meet you. What would you like to know about Mybid?" }],
parts: [
{
text: conversation.questions[0],
},
],
},
...this.conversation.chatHistory,
],
// generationConfig: {
// maxOutputTokens: 200,
// },
};
const aiModel = AiModels.gemini;
const model = await this.generativeModel(aiModel);
return await model.startChat(this.initialConvo);
const model = this.generativeModel(aiModel);
return model.startChat(this.initialConvo);
};

async run(): Promise<IChatResponseDTO> {
const question = `${this.conversation.questions[0]}`;
this.displayChatTokenCount(question);
const chat: ChatSession = await this.initChat();
const result: GenerateContentResult = await chat.sendMessage(question);
const response: EnhancedGenerateContentResponse = await result.response;
const response: EnhancedGenerateContentResponse = result.response;
const answer = response.text();
const chatHistory = JSON.stringify(await chat.getHistory(), null, 2);
return {
Expand All @@ -68,7 +78,9 @@ export class ChatService extends GenerativeAIService {
};
}

displayTokenCount = async (request: string | (string | Part)[] | CountTokensRequest) => {
displayTokenCount = async (
request: string | (string | Part)[] | CountTokensRequest
) => {
const aiModel = AiModels.gemini;
const model = this.generativeModel(aiModel);
const { totalTokens } = await model.countTokens(request);
Expand All @@ -82,7 +94,9 @@ export class ChatService extends GenerativeAIService {
await this.displayTokenCount({ contents: [...history, msgContent] });
};

streamToStdout = async (stream: AsyncGenerator<EnhancedGenerateContentResponse, any, unknown>) => {
streamToStdout = async (
stream: AsyncGenerator<EnhancedGenerateContentResponse, any, unknown>
) => {
console.log("Streaming...\n");
for await (const chunk of stream) {
const chunkText = chunk.text();
Expand Down
32 changes: 21 additions & 11 deletions presentation/src/components/ChatForm.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import { useState } from "react";
import { Button, Card, Col, Container, Form, Row, Stack } from "react-bootstrap";
import {
Button,
Card,
Col,
Container,
Form,
Row,
Stack,
} from "react-bootstrap";
import useAxiosPrivate from "../hooks/useAxiosPrivate";
import DOMPurify from "dompurify";
import { formatText } from "../utils";

// interface Message {
// text: string;
Expand Down Expand Up @@ -34,24 +43,22 @@ export function Thread() {
return;
}
try {
const body = {
const response = await axiosPrivate.post("/chat", {
question,
history: chatHistory,
};

const response = await axiosPrivate.post("/domain/create", JSON.stringify(body));
chatHistory: JSON.stringify(chatHistory),
});
const data = response.data;
console.log(data);
console.log(JSON.stringify(data.data.chatHistory));
setChatHistory((oldChat) => [
...oldChat,
{
role: "user",
parts: [{ text: question }],
},
{
role: "model",
parts: [{ text: data }],
parts: [{ text: data.data.answer }],
},
...oldChat,
]);
setQuestion("");
return data;
Expand Down Expand Up @@ -104,13 +111,16 @@ export function Thread() {
<p>{error}</p>
</div>
{chatHistory.map((chatItem, index) => (
<Card style={{ marginBottom: "10px", marginTop: "10px" }} key={index}>
<Card
style={{ marginBottom: "10px", marginTop: "10px" }}
key={index}
>
<Card.Header>{chatItem.role}</Card.Header>
{chatItem.parts.map((part, i) => (
<Card.Body key={i}>
<Card.Text
dangerouslySetInnerHTML={{
__html: DOMPurify.sanitize(part.text),
__html: DOMPurify.sanitize(formatText(part.text)),
}}
></Card.Text>
</Card.Body>
Expand Down
2 changes: 1 addition & 1 deletion presentation/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export const BASE_URL = "http://localhost:3000";
export const BASE_URL = "http://localhost:4000";
export enum CHAT_PARAMS {
MATCH_COUNT = 3,
SIMILARITY_THRESHOLD = 0.7,
Expand Down
17 changes: 17 additions & 0 deletions presentation/src/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
export const formatText = (text: string) => {
const paragraphs = text.split("**");
let formattedText = "";

for (let i = 0; i < paragraphs.length; i++) {
const paragraph = paragraphs[i].trim();
if (i % 2 === 0) {
formattedText += `<p>${paragraph}</p>`;
} else {
const startIndex = paragraph.indexOf(" ");
const boldText = paragraph.substring(0, startIndex);
const restOfParagraph = paragraph.substring(startIndex).trim();
formattedText += `<p><b>${boldText}</b> ${restOfParagraph}</p>`;
}
}
return formattedText.replace(/\*/g, "");
};

0 comments on commit 5860273

Please sign in to comment.