Skip to content

Commit

Permalink
improve chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 1, 2025
1 parent 3cddaf1 commit 7e6e740
Showing 1 changed file with 88 additions and 70 deletions.
158 changes: 88 additions & 70 deletions agixt/providers/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,38 @@ def score_chunk(chunk: str, keywords: set) -> int:
return score


def chunk_content(text: str, chunk_size: int, max_tokens: int = 60000) -> List[str]:
    """
    Split content into chunks while respecting both character and token limits.

    Sentences are grouped into chunks of roughly ``chunk_size`` characters,
    each chunk is scored by keyword relevance, and the highest-scoring chunks
    are returned until the estimated token budget is exhausted.

    Args:
        text: Text to chunk.
        chunk_size: Target size for each chunk in characters.
        max_tokens: Maximum tokens allowed for processing (default: 60000 for Deepseek).

    Returns:
        Chunk texts sorted by descending relevance score, truncated so the
        estimated total token count stays under ``max_tokens``.
    """
    doc = nlp(text)
    sentences = list(doc.sents)
    content_chunks = []
    chunk = []
    chunk_len = 0
    chunk_text = ""
    total_len = 0

    keywords = set(extract_keywords(doc=doc, limit=10))

    for sentence in sentences:
        sentence_tokens = len(sentence)
        # Estimate tokens (rough approximation: 4 characters per token)
        estimated_total_tokens = (total_len + len(str(sentence))) // 4

        # Stop consuming sentences once the running estimate exceeds the budget;
        # anything already buffered in `chunk` is still flushed below.
        if estimated_total_tokens > max_tokens:
            break

        if chunk_len + sentence_tokens > chunk_size and chunk:
            chunk_text = " ".join(token.text for token in chunk)
            score = score_chunk(chunk_text, keywords)
            content_chunks.append((score, chunk_text))
            total_len += len(chunk_text)
            chunk = []
            chunk_len = 0

        # NOTE(review): these two lines were collapsed in the diff view —
        # reconstructed as the standard token-accumulation step; confirm
        # against the repository.
        chunk.extend(sentence)
        chunk_len += sentence_tokens

    # Flush the trailing partial chunk, if any.
    if chunk:
        chunk_text = " ".join(token.text for token in chunk)
        score = score_chunk(chunk_text, keywords)
        content_chunks.append((score, chunk_text))

    # Sort by score and take only enough chunks to stay under the token limit.
    content_chunks.sort(key=lambda x: x[0], reverse=True)
    result_chunks = []
    total_len = 0

    for score, chunk_text in content_chunks:
        # Estimate tokens for this chunk (same 4-chars-per-token heuristic).
        chunk_tokens = len(chunk_text) // 4
        if total_len + chunk_tokens > max_tokens:
            break
        result_chunks.append(chunk_text)
        total_len += chunk_tokens

    return result_chunks


class RotationProvider:
Expand Down Expand Up @@ -69,89 +101,75 @@ async def _analyze_chunk(
self, chunk: str, chunk_index: int, prompt: str
) -> List[int]:
"""Analyze a single large chunk to identify relevant smaller chunks."""
small_chunks = chunk_content(chunk, self.SMALL_CHUNK_SIZE)
# Use smaller max_tokens to leave room for prompt and completion
small_chunks = chunk_content(chunk, self.SMALL_CHUNK_SIZE, max_tokens=40000)
if not small_chunks:
return []

# Process small chunks in batches to stay within token limits
MAX_CHUNKS_PER_PROMPT = 5 # Adjust this based on actual token usage
results = []

for batch_start in range(0, len(small_chunks), MAX_CHUNKS_PER_PROMPT):
batch_end = min(batch_start + MAX_CHUNKS_PER_PROMPT, len(small_chunks))
batch_chunks = small_chunks[batch_start:batch_end]

analysis_prompt = (
f"Below is part {batch_start//MAX_CHUNKS_PER_PROMPT + 1} of chunk {chunk_index + 1}, "
f"containing sub-chunks {batch_start + 1} to {batch_end} of the total {len(small_chunks)} sub-chunks.\n"
"Analyze which sub-chunks are relevant to answering the query.\n"
"Respond ONLY with comma-separated sub-chunk numbers (using the original full numbering).\n"
"Example response format: 1,4,7\n"
"If no sub-chunks are relevant, respond with: none\n\n"
f"Query: {prompt}\n\n"
"Sub-chunks:\n"
)
analysis_prompt = (
f"Below is chunk {chunk_index + 1} of a larger codebase, split into {len(small_chunks)} "
f"sub-chunks, followed by a user query.\n"
"Analyze which sub-chunks are relevant to answering the query.\n"
"Respond ONLY with comma-separated sub-chunk numbers (1-based indexing).\n"
"Example response format: 1,4,7\n\n"
f"Query: {prompt}\n\n"
"Sub-chunks:\n"
)

for i, small_chunk in enumerate(batch_chunks, batch_start + 1):
analysis_prompt += f"\nSUB-CHUNK {i}:\n{small_chunk}\n"
for i, small_chunk in enumerate(small_chunks, 1):
analysis_prompt += f"\nSUB-CHUNK {i}:\n{small_chunk}\n"

try:
agent = Agent(
agent_name=self.agent_name,
user=self.user,
ApiClient=self.ApiClient,
)
if "agent_name" in self.AGENT_SETTINGS:
del self.AGENT_SETTINGS["agent_name"]
if "user" in self.AGENT_SETTINGS:
del self.AGENT_SETTINGS["user"]
if "ApiClient" in self.AGENT_SETTINGS:
del self.AGENT_SETTINGS["ApiClient"]
agent.PROVIDER = Providers(
name=self.ANALYSIS_PROVIDER,
ApiClient=self.ApiClient,
agent_name=self.agent_name,
user=self.user,
**self.AGENT_SETTINGS,
)
try:
agent = Agent(
agent_name=self.agent_name,
user=self.user,
ApiClient=self.ApiClient,
result = await agent.inference(prompt=analysis_prompt)
except Exception as e:
logging.error(
f"Chunk analysis failed for chunk {chunk_index + 1}: {str(e)}"
)
if "agent_name" in self.AGENT_SETTINGS:
del self.AGENT_SETTINGS["agent_name"]
if "user" in self.AGENT_SETTINGS:
del self.AGENT_SETTINGS["user"]
if "ApiClient" in self.AGENT_SETTINGS:
del self.AGENT_SETTINGS["ApiClient"]
agent.PROVIDER = Providers(
name=self.ANALYSIS_PROVIDER,
name="rotation",
ApiClient=self.ApiClient,
agent_name=self.agent_name,
user=self.user,
**self.AGENT_SETTINGS,
)
try:
result = await agent.inference(prompt=analysis_prompt)
except Exception as e:
logging.error(
f"Chunk analysis failed for batch {batch_start//MAX_CHUNKS_PER_PROMPT + 1} of chunk {chunk_index + 1}: {str(e)}"
)
agent.PROVIDER = Providers(
name="rotation",
ApiClient=self.ApiClient,
agent_name=self.agent_name,
user=self.user,
**self.AGENT_SETTINGS,
)
result = await agent.inference(prompt=analysis_prompt)
result = await agent.inference(prompt=analysis_prompt)

if result.strip().lower() != "none":
# Parse comma-separated numbers, convert to 0-based indexing
chunk_numbers = [int(n.strip()) - 1 for n in result.split(",")]
# Validate chunk numbers
valid_numbers = [
n for n in chunk_numbers if 0 <= n < len(small_chunks)
]
results.extend(valid_numbers)
# Parse comma-separated numbers, convert to 0-based indexing
chunk_numbers = [int(n.strip()) - 1 for n in result.split(",")]
# Validate chunk numbers
valid_numbers = [n for n in chunk_numbers if 0 <= n < len(small_chunks)]

except Exception as e:
logging.error(
f"Batch analysis failed for chunk {chunk_index + 1}, batch {batch_start//MAX_CHUNKS_PER_PROMPT + 1}: {str(e)}"
if not valid_numbers:
logging.warning(
f"No valid chunk numbers returned for chunk {chunk_index + 1}, using all sub-chunks"
)
# On complete failure, include all chunks from this batch
results.extend(range(batch_start, batch_end))
return list(range(len(small_chunks)))

if not results:
logging.warning(
f"No valid chunk numbers returned for any batch in chunk {chunk_index + 1}, using all sub-chunks"
return valid_numbers
except Exception as e:
logging.error(
f"Chunk analysis failed for chunk {chunk_index + 1}: {str(e)}"
)
return list(range(len(small_chunks)))

return sorted(set(results)) # Remove duplicates and sort
return list(range(len(small_chunks))) # Return all sub-chunks on failure

async def _get_relevant_chunks(self, text: str, prompt: str) -> str:
"""Split text into large chunks and analyze them in parallel."""
Expand Down

0 comments on commit 7e6e740

Please sign in to comment.