Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve compound query generation #63

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions api/services/chat.service.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
import {
ChatSession,
CountTokensRequest,
EnhancedGenerateContentResponse,
GenerateContentResult,
Part,
} from "@google/generative-ai";
import { ChatSession, CountTokensRequest, EnhancedGenerateContentResponse, Part } from "@google/generative-ai";
import { oneLine, stripIndents } from "common-tags";
import { AiModels } from "../lib/constants";
import { GenerativeAIService } from "./ai.service";
Expand Down Expand Up @@ -80,9 +74,7 @@ export class ChatService extends GenerativeAIService {
};
}

displayTokenCount = async (
request: string | (string | Part)[] | CountTokensRequest
) => {
displayTokenCount = async (request: string | (string | Part)[] | CountTokensRequest) => {
const aiModel = AiModels.gemini;
const model = this.generativeModel(aiModel);
const { totalTokens } = await model.countTokens(request);
Expand All @@ -96,9 +88,7 @@ export class ChatService extends GenerativeAIService {
await this.displayTokenCount({ contents: [...history, msgContent] });
};

streamToStdout = async (
stream: AsyncGenerator<EnhancedGenerateContentResponse, any, unknown>
) => {
streamToStdout = async (stream: AsyncGenerator<EnhancedGenerateContentResponse, any, unknown>) => {
console.log("Streaming...\n");
for await (const chunk of stream) {
const chunkText = chunk.text();
Expand Down
70 changes: 44 additions & 26 deletions api/services/embed.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { GenerativeAIService } from "./ai.service";
import { DocumentTypeService } from "./document-type.service";
import { DocumentService } from "./document.service";
import { DomainService } from "./domain.service";
import { match } from "assert";
import { oneLine } from "common-tags";

/**The `role` parameter in the `ContentPart` object is used to specify the role of the text content in relation to the task being performed.
* the following roles are commonly used:
Expand Down Expand Up @@ -97,12 +97,14 @@ export class EmbeddingService extends GenerativeAIService implements IEmbeddingS
}
}

/**
* Calculates the cosine similarity between two vectors.
* @param vecA - The first vector.
* @param vecB - The second vector.
* @returns The cosine similarity between the two vectors.
*/
/**
 * Computes the cosine similarity between two vectors of equal length.
 * Cosine similarity is a measure of the similarity between two vectors, and is calculated by finding the dot product
 * of the two vectors divided by the product of their magnitudes.
 * @param vecA - The first vector
 * @param vecB - The second vector
 * @throws {Error} if the vectors are not of equal length
 * @returns {number} - A number between -1 and 1 representing the cosine similarity of the two vectors
 */
cosineSimilarity(vecA: number[], vecB: number[]): number {
let consineDistance = 0;
let dotProduct = 0;
Expand Down Expand Up @@ -210,16 +212,35 @@ export class EmbeddingService extends GenerativeAIService implements IEmbeddingS
return textEmbeddings;
}

/**
 * Breaks a compound query into at most 3 comma-separated sub-queries that can be answered individually;
 * for a simple, single-part query, generates 2 additional similar queries and appends the original query.
 * @param query - The original query
 * @returns A promise that resolves to a comma-separated string of the generated queries
 * */
async generateSimilarQueries(query: string): Promise<string> {
const model = AiModels.gemini;
const aiModel: GenerativeModel = this.generativeModel(model);
const prompt = `Generate 2 additional comma seperated queries that are similar to this query and append the original query too: ${query}`;
const prompt = oneLine`
when asked a compound question that contains multiple parts,
I want you to break it down into separate sub-queries that can be answered individually,
the query should be broken down to at most 3 parts, return comma seperated queries.
However if the question is a single question, straight forward query without multiple parts,
Generate 2 additional comma seperated queries that are similar to this query and append the original query too: ${query}
`;
const result: GenerateContentResult = await aiModel.generateContent(prompt);
const response: EnhancedGenerateContentResponse = result.response;
const text: string = response.text();
console.log(text);
return text;
}

/**
* Generates query embeddings for retrieval task
* Generates similar queries and then generates embeddings for each query
* @param query - The query to generate embeddings for
* @returns A Promise that resolves to a 2D array of embeddings
* @throws {HttpException} if unable to generate similar queries
**/
async generateUserQueryEmbeddings(query: string): Promise<number[][]> {
const queries = await this.generateSimilarQueries(query);
if (!queries?.length) {
Expand All @@ -233,30 +254,27 @@ export class EmbeddingService extends GenerativeAIService implements IEmbeddingS
return embeddings.map((e) => e.embedding);
}

/**
* Generates query matches for the given user query, match count, and similarity threshold.
* 1. Generate embeddings for the user query.
* 2. Match documents to the query embeddings.
* 3. Flattens the resulting matches.
* @param query - The user query to match against.
* @param matchCount - The number of matches to return per embedding.
* @param similarityThreshold - The minimum similarity score to consider a match.
* @returns An array of query matches.
* @throws {HttpException} if query embeddings could not be generated.
**/
async getQueryMatches(query: string, matchCount: number, similarityThreshold: number): Promise<IQueryMatch[]> {
const queryEmbeddings = await this.generateUserQueryEmbeddings(query);
if (!queryEmbeddings?.length) {
throw new HttpException(HTTP_RESPONSE_CODE.BAD_REQUEST, "Unable to generate user query embeddings");
}
const embeddingRepository: EmbeddingRepository = new EmbeddingRepository();
const [firstEmbeddings, secondEmbeddings, thirdEmbeddings] = queryEmbeddings;
//Check if this works with map and promise.all
const originalQuery: IQueryMatch[] = await embeddingRepository.matchDocuments(
firstEmbeddings,
matchCount,
similarityThreshold
);
const intialAiGenratedQuery: IQueryMatch[] = await embeddingRepository.matchDocuments(
secondEmbeddings,
matchCount,
similarityThreshold
);
const otherAiGenratedQuery: IQueryMatch[] = await embeddingRepository.matchDocuments(
thirdEmbeddings,
matchCount,
similarityThreshold
const embeddings = queryEmbeddings.map((embedding) =>
embeddingRepository.matchDocuments(embedding, matchCount, similarityThreshold)
);
const matches: IQueryMatch[] = [...originalQuery, ...intialAiGenratedQuery, ...otherAiGenratedQuery];
return matches;
const matches = await Promise.all(embeddings);
return matches.flat();
}
}
35 changes: 7 additions & 28 deletions presentation/src/components/ChatForm.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
import DOMPurify from "dompurify";
import { useState } from "react";
import {
Button,
Card,
Col,
Container,
Form,
Row,
Stack,
} from "react-bootstrap";
import { Button, Card, Col, Container, Form, Row, Stack } from "react-bootstrap";
import useAxiosPrivate from "../hooks/useAxiosPrivate";
import { formatCodeBlocks, formatText } from "../utils";
import NavBar from "./NavBar";
Expand Down Expand Up @@ -93,11 +85,7 @@ export function Thread() {
Send
</Button>
<div className="vr" />
<Button
variant="outline-danger"
onClick={clearChat}
disabled={loading}
>
<Button variant="outline-danger" onClick={clearChat} disabled={loading}>
Reset
</Button>
</Stack>
Expand All @@ -111,15 +99,15 @@ export function Thread() {
{loading ? (
<>
<div
className="loading-skeleton"
className="loader"
style={{
marginBottom: "10px",
marginTop: "10px",
height: "70px",
}}
></div>
<div
className="loading-skeleton"
className="loader"
style={{
marginBottom: "10px",
marginTop: "10px",
Expand All @@ -133,22 +121,13 @@ export function Thread() {
</div>
<div>
{chatHistory.map((chatItem, index) => (
<Card
style={{ marginBottom: "10px", marginTop: "10px" }}
key={index}
>
<Card.Header>
{chatItem.role && chatItem.role === "user"
? "Question"
: "Answer"}
</Card.Header>
<Card style={{ marginBottom: "10px", marginTop: "10px" }} key={index}>
<Card.Header>{chatItem.role && chatItem.role === "user" ? "Question" : "Answer"}</Card.Header>
{chatItem.parts.map((part, i) => (
<Card.Body key={i}>
<Card.Text
dangerouslySetInnerHTML={{
__html: DOMPurify.sanitize(
formatCodeBlocks(formatText(part.text))
),
__html: DOMPurify.sanitize(formatCodeBlocks(formatText(part.text))),
}}
></Card.Text>
</Card.Body>
Expand Down
2 changes: 1 addition & 1 deletion presentation/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export const BASE_URL = "http://localhost:4000";
export const BASE_URL = "http://localhost:3000";
export enum CHAT_PARAMS {
MATCH_COUNT = 3,
SIMILARITY_THRESHOLD = 0.7,
Expand Down
78 changes: 59 additions & 19 deletions presentation/src/index.css
Original file line number Diff line number Diff line change
@@ -1,25 +1,65 @@
.loading-skeleton {
background-color: #f0f0f0;
border-radius: 4px;
animation: loading 1s infinite ease-in-out;
background-color: #f0f0f0;
border-radius: 4px;
animation: loading 1s infinite ease-in-out;
}

@keyframes loading {
0% {
opacity: 0.5;
}

@keyframes loading {
0% { opacity: 0.5; }
50% { opacity: 1; }
100% { opacity: 0.5; }

50% {
opacity: 1;
}

.straight-line {
position: fixed;
top: 0;
left: 50%;
width: 1px;
height: 100vh;
background-color: black;
z-index: 999;
100% {
opacity: 0.5;
}
}

.straight-line {
position: fixed;
top: 0;
left: 50%;
width: 1px;
height: 100vh;
background-color: black;
z-index: 999;
}

body {
background-color: #f8f9fa;
}

body{
background-color: #f8f9fa;
}
.loader {
position: relative;
background-color: rgb(235, 235, 235);
max-width: 100%;
height: auto;
background: #efefee;
overflow: hidden;
border-radius: 4px;
margin-bottom: 4px;
}

.loader::after {
display: block;
content: "";
position: absolute;
width: 100%;
height: 100%;
transform: translateX(-100%);
background: linear-gradient(90deg, transparent, #f1f1f1, transparent);
background: linear-gradient(90deg,
transparent,
rgba(255, 255, 255, 0.4),
transparent);
animation: loading 1s infinite;
}

@keyframes loading {
100% {
transform: translateX(100%);
}
}
Loading