Skip to content

Commit

Permalink
improve compound query generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Olasunkanmi Oyinlola authored and Olasunkanmi Oyinlola committed Apr 11, 2024
1 parent fc3ddf8 commit 36ec297
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 87 deletions.
16 changes: 3 additions & 13 deletions api/services/chat.service.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
import {
ChatSession,
CountTokensRequest,
EnhancedGenerateContentResponse,
GenerateContentResult,
Part,
} from "@google/generative-ai";
import { ChatSession, CountTokensRequest, EnhancedGenerateContentResponse, Part } from "@google/generative-ai";
import { oneLine, stripIndents } from "common-tags";
import { AiModels } from "../lib/constants";
import { GenerativeAIService } from "./ai.service";
Expand Down Expand Up @@ -80,9 +74,7 @@ export class ChatService extends GenerativeAIService {
};
}

displayTokenCount = async (
request: string | (string | Part)[] | CountTokensRequest
) => {
displayTokenCount = async (request: string | (string | Part)[] | CountTokensRequest) => {
const aiModel = AiModels.gemini;
const model = this.generativeModel(aiModel);
const { totalTokens } = await model.countTokens(request);
Expand All @@ -96,9 +88,7 @@ export class ChatService extends GenerativeAIService {
await this.displayTokenCount({ contents: [...history, msgContent] });
};

streamToStdout = async (
stream: AsyncGenerator<EnhancedGenerateContentResponse, any, unknown>
) => {
streamToStdout = async (stream: AsyncGenerator<EnhancedGenerateContentResponse, any, unknown>) => {
console.log("Streaming...\n");
for await (const chunk of stream) {
const chunkText = chunk.text();
Expand Down
70 changes: 44 additions & 26 deletions api/services/embed.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { GenerativeAIService } from "./ai.service";
import { DocumentTypeService } from "./document-type.service";
import { DocumentService } from "./document.service";
import { DomainService } from "./domain.service";
import { match } from "assert";
import { oneLine } from "common-tags";

/**The `role` parameter in the `ContentPart` object is used to specify the role of the text content in relation to the task being performed.
* the following roles are commonly used:
Expand Down Expand Up @@ -97,12 +97,14 @@ export class EmbeddingService extends GenerativeAIService implements IEmbeddingS
}
}

/**
* Calculates the cosine similarity between two vectors.
* @param vecA - The first vector.
* @param vecB - The second vector.
* @returns The cosine similarity between the two vectors.
*/
/** Computes the cosine similarity between two vectors of equal length.
* Cosine similarity is a measure of the similarity between two vectors, and is calculated by finding the dot product
* of the two vectors divided by the product of their magnitudes.
* @param vecA - The first vector
* @param vecB - The second vector
* @throws {Error} if the vectors are not of equal length
* @returns {number} - A number between -1 and 1 representing the cosine similarity of the two vectors
*/
cosineSimilarity(vecA: number[], vecB: number[]): number {
let consineDistance = 0;
let dotProduct = 0;
Expand Down Expand Up @@ -210,16 +212,35 @@ export class EmbeddingService extends GenerativeAIService implements IEmbeddingS
return textEmbeddings;
}

/**
* Generates 2 similar queries and appends the original query to the generated queries.
* @param query - The original query
* @returns A promise that resolves to a string of the generated queries
* */
async generateSimilarQueries(query: string): Promise<string> {
const model = AiModels.gemini;
const aiModel: GenerativeModel = this.generativeModel(model);
const prompt = `Generate 2 additional comma seperated queries that are similar to this query and append the original query too: ${query}`;
const prompt = oneLine`
when asked a compound question that contains multiple parts,
I want you to break it down into separate sub-queries that can be answered individually,
the query should be broken down to at most 3 parts, return comma seperated queries.
However if the question is a single question, straight forward query without multiple parts,
Generate 2 additional comma seperated queries that are similar to this query and append the original query too: ${query}
`;
const result: GenerateContentResult = await aiModel.generateContent(prompt);
const response: EnhancedGenerateContentResponse = result.response;
const text: string = response.text();
console.log(text);
return text;
}

/**
* Generates query embeddings for retrieval task
* Generates similar queries and then generates embeddings for each query
* @param query - The query to generate embeddings for
* @returns A Promise that resolves to a 2D array of embeddings
* @throws {HttpException} if unable to generate similar queries
**/
async generateUserQueryEmbeddings(query: string): Promise<number[][]> {
const queries = await this.generateSimilarQueries(query);
if (!queries?.length) {
Expand All @@ -233,30 +254,27 @@ export class EmbeddingService extends GenerativeAIService implements IEmbeddingS
return embeddings.map((e) => e.embedding);
}

/**
* Generates query matches for the given user query, match count, and similarity threshold.
* 1. Generate embeddings for the user query.
* 2. Match documents to the query embeddings.
* 3. Flattens the resulting matches.
* @param query - The user query to match against.
* @param matchCount - The number of matches to return per embedding.
* @param similarityThreshold - The minimum similarity score to consider a match.
* @returns An array of query matches.
* @throws {HttpException} if query embeddings could not be generated.
**/
async getQueryMatches(query: string, matchCount: number, similarityThreshold: number): Promise<IQueryMatch[]> {
const queryEmbeddings = await this.generateUserQueryEmbeddings(query);
if (!queryEmbeddings?.length) {
throw new HttpException(HTTP_RESPONSE_CODE.BAD_REQUEST, "Unable to generate user query embeddings");
}
const embeddingRepository: EmbeddingRepository = new EmbeddingRepository();
const [firstEmbeddings, secondEmbeddings, thirdEmbeddings] = queryEmbeddings;
//Check if this works with map and promise.all
const originalQuery: IQueryMatch[] = await embeddingRepository.matchDocuments(
firstEmbeddings,
matchCount,
similarityThreshold
);
const intialAiGenratedQuery: IQueryMatch[] = await embeddingRepository.matchDocuments(
secondEmbeddings,
matchCount,
similarityThreshold
);
const otherAiGenratedQuery: IQueryMatch[] = await embeddingRepository.matchDocuments(
thirdEmbeddings,
matchCount,
similarityThreshold
const embeddings = queryEmbeddings.map((embedding) =>
embeddingRepository.matchDocuments(embedding, matchCount, similarityThreshold)
);
const matches: IQueryMatch[] = [...originalQuery, ...intialAiGenratedQuery, ...otherAiGenratedQuery];
return matches;
const matches = await Promise.all(embeddings);
return matches.flat();
}
}
35 changes: 7 additions & 28 deletions presentation/src/components/ChatForm.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
import DOMPurify from "dompurify";
import { useState } from "react";
import {
Button,
Card,
Col,
Container,
Form,
Row,
Stack,
} from "react-bootstrap";
import { Button, Card, Col, Container, Form, Row, Stack } from "react-bootstrap";
import useAxiosPrivate from "../hooks/useAxiosPrivate";
import { formatCodeBlocks, formatText } from "../utils";
import NavBar from "./NavBar";
Expand Down Expand Up @@ -93,11 +85,7 @@ export function Thread() {
Send
</Button>
<div className="vr" />
<Button
variant="outline-danger"
onClick={clearChat}
disabled={loading}
>
<Button variant="outline-danger" onClick={clearChat} disabled={loading}>
Reset
</Button>
</Stack>
Expand All @@ -111,15 +99,15 @@ export function Thread() {
{loading ? (
<>
<div
className="loading-skeleton"
className="loader"
style={{
marginBottom: "10px",
marginTop: "10px",
height: "70px",
}}
></div>
<div
className="loading-skeleton"
className="loader"
style={{
marginBottom: "10px",
marginTop: "10px",
Expand All @@ -133,22 +121,13 @@ export function Thread() {
</div>
<div>
{chatHistory.map((chatItem, index) => (
<Card
style={{ marginBottom: "10px", marginTop: "10px" }}
key={index}
>
<Card.Header>
{chatItem.role && chatItem.role === "user"
? "Question"
: "Answer"}
</Card.Header>
<Card style={{ marginBottom: "10px", marginTop: "10px" }} key={index}>
<Card.Header>{chatItem.role && chatItem.role === "user" ? "Question" : "Answer"}</Card.Header>
{chatItem.parts.map((part, i) => (
<Card.Body key={i}>
<Card.Text
dangerouslySetInnerHTML={{
__html: DOMPurify.sanitize(
formatCodeBlocks(formatText(part.text))
),
__html: DOMPurify.sanitize(formatCodeBlocks(formatText(part.text))),
}}
></Card.Text>
</Card.Body>
Expand Down
2 changes: 1 addition & 1 deletion presentation/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export const BASE_URL = "http://localhost:4000";
export const BASE_URL = "http://localhost:3000";
export enum CHAT_PARAMS {
MATCH_COUNT = 3,
SIMILARITY_THRESHOLD = 0.7,
Expand Down
78 changes: 59 additions & 19 deletions presentation/src/index.css
Original file line number Diff line number Diff line change
@@ -1,25 +1,65 @@
.loading-skeleton {
background-color: #f0f0f0;
border-radius: 4px;
animation: loading 1s infinite ease-in-out;
background-color: #f0f0f0;
border-radius: 4px;
animation: loading 1s infinite ease-in-out;
}

@keyframes loading {
0% {
opacity: 0.5;
}

@keyframes loading {
0% { opacity: 0.5; }
50% { opacity: 1; }
100% { opacity: 0.5; }

50% {
opacity: 1;
}

.straight-line {
position: fixed;
top: 0;
left: 50%;
width: 1px;
height: 100vh;
background-color: black;
z-index: 999;
100% {
opacity: 0.5;
}
}

.straight-line {
position: fixed;
top: 0;
left: 50%;
width: 1px;
height: 100vh;
background-color: black;
z-index: 999;
}

body {
background-color: #f8f9fa;
}

body{
background-color: #f8f9fa;
}
.loader {
position: relative;
background-color: rgb(235, 235, 235);
max-width: 100%;
height: auto;
background: #efefee;
overflow: hidden;
border-radius: 4px;
margin-bottom: 4px;
}

.loader::after {
display: block;
content: "";
position: absolute;
width: 100%;
height: 100%;
transform: translateX(-100%);
background: linear-gradient(90deg, transparent, #f1f1f1, transparent);
background: linear-gradient(90deg,
transparent,
rgba(255, 255, 255, 0.4),
transparent);
animation: loading 1s infinite;
}

@keyframes loading {
100% {
transform: translateX(100%);
}
}

0 comments on commit 36ec297

Please sign in to comment.