Skip to content

Commit

Permalink
add the feature to be able to select document
Browse files Browse the repository at this point in the history
  • Loading branch information
Olasunkanmi Oyinlola authored and Olasunkanmi Oyinlola committed Apr 19, 2024
1 parent bc6d0a5 commit 8d257bb
Show file tree
Hide file tree
Showing 21 changed files with 246 additions and 81 deletions.
19 changes: 5 additions & 14 deletions api/controllers/chat-controller.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import * as express from "express";
import { ChatHandler } from "../handlers/chat.handler";
import {
chatHistorySchema,
chatRequestSchema,
} from "../lib/validation-schemas";
import { chatHistorySchema, chatRequestSchema } from "../lib/validation-schemas";
import { Result } from "../lib/result";
import { generateErrorResponse } from "../utils/utils";
export class ChatController {
Expand All @@ -17,23 +14,17 @@ export class ChatController {
this.router.post(`${this.path}`, this.chat);
}

async chat(
req: express.Request,
res: express.Response,
next: express.NextFunction
) {
async chat(req: express.Request, res: express.Response, next: express.NextFunction) {
try {
const { question, chatHistory } = chatRequestSchema.parse(req.body);
const { question, chatHistory, documentId } = chatRequestSchema.parse(req.body);
const chatHandler = new ChatHandler();
const history = chatHistorySchema.parse(JSON.parse(chatHistory));
const data = await chatHandler.handle({ question, chatHistory: history });
const data = await chatHandler.handle({ question, chatHistory: history, documentId });
if (data) {
const result = Result.ok(data.getValue());
res.status(200).json(result);
} else {
res
.status(400)
.json(Result.fail("Unable to create document type", 400));
res.status(400).json(Result.fail("Unable to generate model response", 400));
}
} catch (error) {
generateErrorResponse(error, res, next);
Expand Down
31 changes: 31 additions & 0 deletions api/controllers/document-controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import * as express from "express";
import { Result } from "../lib/result";
import { generateErrorResponse } from "../utils/utils";
import { GetDocumentsHandler } from "../handlers/get-documents-handler";

export class DocumentController {
path = "/documents";
router = express.Router();

constructor() {
this.initializeRoute();
}

initializeRoute() {
this.router.get(this.path, this.getDocument);
}

async getDocument(req: express.Request, res: any, next: express.NextFunction) {
try {
const documentHandler = new GetDocumentsHandler();
const data = await documentHandler.handle();
if (data) {
const result = Result.ok(data.getValue());
res.status(200).json(result);
}
} catch (error) {
generateErrorResponse(error, res, next);
next(error);
}
}
}
21 changes: 5 additions & 16 deletions api/handlers/chat.handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,14 @@ import { EmbeddingService } from "../services/embed.service";
import { getValue } from "../utils";
import { CHAT_PARAMS } from "./../../presentation/src/constants";

export class ChatHandler
implements
IRequestHandler<IChatRequestDTO, Result<Partial<IChatResponseDTO>>>
{
export class ChatHandler implements IRequestHandler<IChatRequestDTO, Result<Partial<IChatResponseDTO>>> {
private readonly apiKey: string = getValue("API_KEY");
async handle({
question,
chatHistory,
}: IChatRequestDTO): Promise<Result<Partial<IChatResponseDTO>>> {
async handle({ question, chatHistory, documentId }: IChatRequestDTO): Promise<Result<Partial<IChatResponseDTO>>> {
try {
const embeddingService: EmbeddingService = new EmbeddingService(
this.apiKey
);
const embeddingService: EmbeddingService = new EmbeddingService(this.apiKey);
const { MATCH_COUNT, SIMILARITY_THRESHOLD } = CHAT_PARAMS;
const matches = await embeddingService.getQueryMatches(
question,
MATCH_COUNT,
SIMILARITY_THRESHOLD
);
//Query here
const matches = await embeddingService.getQueryMatches(question, MATCH_COUNT, SIMILARITY_THRESHOLD, documentId);
// if (!matches?.length) {
// //take care of empty results here
// return "No matches for user query";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { DocumentRepository } from "../repositories/document.repository";
import { ICreateDocumentDTO } from "../repositories/dtos/dtos";
import { IDocumentModel } from "../repositories/model";

export class DocumentHandler implements IRequestHandler<ICreateDocumentDTO, Result<IDocumentModel>> {
export class CreateDocumentHandler implements IRequestHandler<ICreateDocumentDTO, Result<IDocumentModel>> {
async handle(request: ICreateDocumentDTO): Promise<Result<IDocumentModel>> {
try {
let response: IDocumentModel | undefined;
Expand Down
17 changes: 17 additions & 0 deletions api/handlers/get-documents-handler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { IRequestHandler } from "../interfaces/handler";
import { Result } from "../lib/result";
import { DocumentRepository } from "../repositories/document.repository";
import { IDocumentModel } from "../repositories/model";

export class GetDocumentsHandler implements IRequestHandler<{}, Result<IDocumentModel[]>> {
async handle(): Promise<Result<IDocumentModel[]>> {
try {
let response: IDocumentModel[];
const documentRespository: DocumentRepository = new DocumentRepository();
response = await documentRespository.getDocuments();
return Result.ok(response);
} catch (error) {
console.error(error);
}
}
}
9 changes: 8 additions & 1 deletion api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@ import { ChatController } from "./controllers/chat-controller";
import { DocmentTypeController } from "./controllers/document-type.controller";
import { DomainController } from "./controllers/domain.controller";
import { EmbeddingController } from "./controllers/embed.controller";
import { DocumentController } from "./controllers/document-controller";

const port: number = Number(process.env.PORT) || 3000;
const app = new App(
[new EmbeddingController(), new DomainController(), new DocmentTypeController(), new ChatController()],
[
new EmbeddingController(),
new DomainController(),
new DocmentTypeController(),
new ChatController(),
new DocumentController(),
],
port
);
app.listen();
5 changes: 3 additions & 2 deletions api/interfaces/embedding-service.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { IQueryMatch } from "./generic-interface";
export interface IEmbeddingService {
generateEmbeddings(
taskType: TaskType,
role?: string,
role?: string
): Promise<{
embedding: number[];
text: string;
Expand All @@ -16,11 +16,12 @@ export interface IEmbeddingService {
createDocumentsEmbeddings(
title: string,
documentType: DocumentTypeEnum,
domain: DomainEnum,
domain: DomainEnum
): Promise<Result<boolean>>;
getQueryMatches(
query: string,
matchCount: number,
similarityThreshold: number,
documentId: number
): Promise<IQueryMatch[]>;
}
1 change: 1 addition & 0 deletions api/interfaces/handler.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Todo, refactor IRequestHandler< TResponse, TRequest = any> so as to make TRequest optional type
export interface IRequestHandler<TRequest, TResponse> {
handle(request?: TRequest): Promise<TResponse>;
}
3 changes: 3 additions & 0 deletions api/lib/validation-schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ export const domainRequestSchema = z.object({ name });
const docType = z.nativeEnum(DocumentTypeEnum);
export const docTypeRequestSchema = z.object({ name: docType });

export const createDocumentSchema = z.object({ title: z.string() });

export const chatRequestSchema = z.object({
documentId: z.number(),
question: z.string(),
metaData: z.optional(
z.object({
Expand Down
13 changes: 9 additions & 4 deletions api/repositories/document.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ export class DocumentRepository extends Database {
//In your instructiona prompt, ask the AI to genrate code if any is available
const docExists: IDocumentModel = await this.findOne(title);
if (docExists) {
throw new HttpException(
HTTP_RESPONSE_CODE.BAD_REQUEST,
"document already exists",
);
throw new HttpException(HTTP_RESPONSE_CODE.BAD_REQUEST, "document already exists");
}
return await this.prisma.documents.create({
data: {
Expand All @@ -44,6 +41,14 @@ export class DocumentRepository extends Database {
}
}

async getDocuments(): Promise<IDocumentModel[]> {
try {
return await this.prisma.documents.findMany();
} catch (error) {
console.error(error);
}
}

async insertMany(): Promise<Prisma.PrismaPromise<{ count: number }>> {
try {
const result = await this.prisma.documents.createMany();
Expand Down
1 change: 1 addition & 0 deletions api/repositories/dtos/dtos.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export interface ICreateDocumentTypeRequestDTO {
}

export interface IChatRequestDTO {
documentId: number;
question: string;
metaData?: {
documentId: number;
Expand Down
13 changes: 11 additions & 2 deletions api/repositories/embedding.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,14 @@ export class EmbeddingRepository extends Database {
/**
* Queries the database for listings that are similar to a given embedding.
*/
async matchDocuments(embedding: any, matchCount: number, matchThreshold: number): Promise<IQueryMatch[]> {
//Raw query failed. Code: `42601`. Message: `ERROR: syntax error at or near "WHERE"`
async matchDocuments(
embedding: any,
matchCount: number,
matchThreshold: number,
documentId: number
): Promise<IQueryMatch[]> {
console.log({ documentId });
//change text to document_embedding
//check how to select textembedding from DB
const matches = await this.prisma.$queryRaw`
Expand All @@ -141,7 +148,9 @@ export class EmbeddingRepository extends Database {
1 - ("textEmbedding" <=> ${embedding}::vector) as similarity
FROM
"Embeddings"
WHERE
WHERE
"documentId" = ${documentId}
AND
1 - ("textEmbedding" <=> ${embedding}::vector) > ${matchThreshold}
ORDER BY
similarity DESC
Expand Down
9 changes: 7 additions & 2 deletions api/services/embed.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,20 @@ export class EmbeddingService extends GenerativeAIService implements IEmbeddingS
* @returns An array of query matches.
* @throws {HttpException} if query embeddings could not be generated.
**/
async getQueryMatches(query: string, matchCount: number, similarityThreshold: number): Promise<IQueryMatch[]> {
async getQueryMatches(
query: string,
matchCount: number,
similarityThreshold: number,
documentId: number
): Promise<IQueryMatch[]> {
const queryEmbeddings = await this.generateUserQueryEmbeddings(query);
if (!queryEmbeddings?.length) {
throw new HttpException(HTTP_RESPONSE_CODE.BAD_REQUEST, "Unable to generate user query embeddings");
}
const embeddingRepository: EmbeddingRepository = new EmbeddingRepository();
const embeddings = queryEmbeddings.map((embedding) =>
//passing in the documentId here.
embeddingRepository.matchDocuments(embedding, matchCount, similarityThreshold)
embeddingRepository.matchDocuments(embedding, matchCount, similarityThreshold, documentId)
);
const matches = await Promise.all(embeddings);
return matches.flat();
Expand Down
8 changes: 8 additions & 0 deletions presentation/src/ErrorFallBackComponent.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export function ErrorFallBackComponent() {
return (
<div>
<h1>Oops! Something went wrong.</h1>
<p>Please try again later.</p>
</div>
);
}
26 changes: 21 additions & 5 deletions presentation/src/components/ChatForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import useAxiosPrivate from "../hooks/useAxiosPrivate";
import NavBar from "./NavBar";
import markdownIt from "markdown-it";
import Books from "./DropDown";
import { IDocument } from "../interfaces/document.interface";

interface IHistory {
role: string;
Expand All @@ -17,6 +18,11 @@ export function Thread() {
const [question, setQuestion] = useState("");
const [chatHistory, setChatHistory] = useState<IHistory[]>([]);
const [loading, setLoading] = useState(false);
const [selectedBook, setSelectedBook] = useState<IDocument>();

const handleBookSelect = (bookData: IDocument) => {
setSelectedBook(bookData);
};

const formAction = async () => {
if (!question) {
Expand All @@ -28,6 +34,7 @@ export function Thread() {
setQuestion("");
console.log(chatHistory);
const response = await axiosPrivate.post("/chat", {
documentId: selectedBook?.id,
question,
chatHistory: JSON.stringify(chatHistory.slice(0, 4)),
});
Expand Down Expand Up @@ -76,7 +83,7 @@ export function Thread() {
<div style={{ marginTop: "20px" }}>
<Stack direction="horizontal" gap={3}>
<div className="p-2">
<Books />
<Books onBookSelect={handleBookSelect} />
</div>
<div className="p-2">
<Form onSubmit={handleSubmit}>
Expand Down Expand Up @@ -113,15 +120,15 @@ export function Thread() {
style={{
marginBottom: "10px",
marginTop: "10px",
height: "70px",
height: "30px",
}}
></div>
<div
className="loader"
style={{
marginBottom: "10px",
marginTop: "10px",
height: "140px",
height: "70px",
}}
></div>
</>
Expand All @@ -131,8 +138,17 @@ export function Thread() {
</div>
<div>
{chatHistory.map((chatItem, index) => (
<Card style={{ marginBottom: "10px", marginTop: "10px" }} key={index}>
<Card.Header>{chatItem.role && chatItem.role === "user" ? "Question" : "Answer"}</Card.Header>
<Card
style={{
marginBottom: "10px",
marginTop: "10px",
backgroundColor: "#212529",
color: "#fff",
borderColor: `${chatItem.role && chatItem.role === "user" ? "" : "#2c2c29"}`,
borderWidth: "medium",
}}
key={index}
>
{chatItem.parts.map((part, i) => (
<Card.Body key={i}>
<Card.Text
Expand Down
Loading

0 comments on commit 8d257bb

Please sign in to comment.