Improve token counting and use token count from github response
Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
deathbeam committed Nov 6, 2024
1 parent 760d291 commit 18d5175
Showing 3 changed files with 99 additions and 67 deletions.
158 changes: 95 additions & 63 deletions lua/CopilotChat/copilot.lua
@@ -14,7 +14,7 @@
 ---@field system_prompt string?
 ---@field model string?
 ---@field temperature number?
----@field on_done nil|fun(response: string, token_count: number?):nil
+---@field on_done nil|fun(response: string, token_count: number?, token_max_count: number?):nil
 ---@field on_progress nil|fun(response: string):nil
 ---@field on_error nil|fun(err: string):nil

@@ -290,13 +290,20 @@ local function generate_headers(token, sessionid, machineid)
   return headers
 end

+local function count_history_tokens(history)
+  local count = 0
+  for _, msg in ipairs(history) do
+    count = count + tiktoken.count(msg.content)
+  end
+  return count
+end
+
 local Copilot = class(function(self, proxy, allow_insecure)
   self.proxy = proxy
   self.allow_insecure = allow_insecure
   self.github_token = get_cached_token()
   self.history = {}
   self.token = nil
-  self.token_count = 0
   self.sessionid = nil
   self.machineid = machine_id()
   self.current_job = nil
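Note: `count_history_tokens` assumes the tokenizer has already been loaded via `tiktoken.load`. A hypothetical check with made-up messages:

    local history = {
      { role = 'user', content = 'How big is my chat history?' },
      { role = 'assistant', content = 'Sum the token counts per message.' },
    }
    print(count_history_tokens(history)) -- total tokens across both messages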
@@ -482,12 +489,12 @@ function Copilot:ask(prompt, opts)
   local on_progress = opts.on_progress
   local on_error = opts.on_error

-  log.debug('System prompt: ' .. system_prompt)
+  log.trace('System prompt: ' .. system_prompt)
+  log.trace('Selection: ' .. selection)
   log.debug('Prompt: ' .. prompt)
   log.debug('Embeddings: ' .. #embeddings)
   log.debug('Filename: ' .. filename)
   log.debug('Filetype: ' .. filetype)
-  log.debug('Selection: ' .. selection)
   log.debug('Model: ' .. model)
   log.debug('Temperature: ' .. temperature)

@@ -510,26 +517,49 @@ function Copilot:ask(prompt, opts)
   local embeddings_message = generate_embeddings_message(embeddings)

   tiktoken.load(tokenizer, function()
-    -- Count tokens
-    self.token_count = self.token_count + tiktoken.count(prompt)
-    local current_count = 0
-    current_count = current_count + tiktoken.count(system_prompt)
-    current_count = current_count + tiktoken.count(selection_message)
+    -- Count required tokens that we cannot reduce
+    local prompt_tokens = tiktoken.count(prompt)
+    local system_tokens = tiktoken.count(system_prompt)
+    local selection_tokens = tiktoken.count(selection_message)
+    local required_tokens = prompt_tokens + system_tokens + selection_tokens

-    -- Limit the number of files to send
+    -- Reserve space for first embedding if its smaller than half of max tokens
+    local reserved_tokens = 0
+    if #embeddings_message.files > 0 then
+      local file_tokens = tiktoken.count(embeddings_message.files[1])
+      if file_tokens < max_tokens / 2 then
+        reserved_tokens = tiktoken.count(embeddings_message.header) + file_tokens
+      end
+    end
+
+    -- Calculate how many tokens we can use for history
+    local history_limit = max_tokens - required_tokens - reserved_tokens
+    local history_tokens = count_history_tokens(self.history)
+
+    -- If we're over history limit, truncate history from the beginning
+    while history_tokens > history_limit and #self.history > 0 do
+      local removed = table.remove(self.history, 1)
+      history_tokens = history_tokens - tiktoken.count(removed.content)
+    end
+
+    -- Now add as many files as possible with remaining token budget
+    local remaining_tokens = max_tokens - required_tokens - history_tokens
     if #embeddings_message.files > 0 then
+      remaining_tokens = remaining_tokens - tiktoken.count(embeddings_message.header)
       local filtered_files = {}
-      current_count = current_count + tiktoken.count(embeddings_message.header)
       for _, file in ipairs(embeddings_message.files) do
-        local file_count = current_count + tiktoken.count(file)
-        if file_count + self.token_count < max_tokens then
-          current_count = file_count
+        local file_tokens = tiktoken.count(file)
+        if remaining_tokens - file_tokens >= 0 then
+          remaining_tokens = remaining_tokens - file_tokens
           table.insert(filtered_files, file)
+        else
+          break
         end
       end
       embeddings_message.files = filtered_files
     end

+    -- Generate the request
     local url = 'https://api.githubcopilot.com/chat/completions'
     local body = vim.json.encode(
       generate_ask_request(
@@ -543,47 +573,67 @@
       )
     )

-    -- Add the prompt to history after we have encoded the request
-    table.insert(self.history, {
-      content = prompt,
-      role = 'user',
-    })
-
     local errored = false
     local last_message = nil
     local full_response = ''

-    local function stream_func(err, line)
-      if not line or errored then
-        return
-      end
-
-      if err or vim.startswith(line, '{"error"') then
-        err = 'Failed to get response: ' .. (err and vim.inspect(err) or line)
+    local function handle_error(error_msg)
+      if not errored then
         errored = true
-        log.error(err)
+        log.error(error_msg)
         if self.current_job and on_error then
-          on_error(err)
+          on_error(error_msg)
         end
+      end
+    end
+
+    local function callback_func(response)
+      if not response then
+        handle_error('Failed to get response')
         return
       end

-      line = line:gsub('data: ', '')
-      if line == '' then
+      if response.status ~= 200 then
+        handle_error(
+          'Failed to get response: ' .. tostring(response.status) .. '\n' .. response.body
+        )
         return
-      elseif line == '[DONE]' then
-        log.trace('Full response: ' .. full_response)
-        log.debug('Last message: ' .. vim.inspect(last_message))
-        self.token_count = self.token_count + tiktoken.count(full_response)
+      end

-        if self.current_job and on_done then
-          on_done(full_response, self.token_count + current_count)
-        end
+      log.trace('Full response: ' .. full_response)
+      log.debug('Last message: ' .. vim.inspect(last_message))

-        table.insert(self.history, {
-          content = full_response,
-          role = 'assistant',
-        })
+      if on_done then
+        on_done(
+          full_response,
+          last_message and last_message.usage and last_message.usage.total_tokens,
+          max_tokens
+        )
+      end
+
+      table.insert(self.history, {
+        content = prompt,
+        role = 'user',
+      })
+
+      table.insert(self.history, {
+        content = full_response,
+        role = 'assistant',
+      })
+    end
+
+    local function stream_func(err, line)
+      if not line or errored or not self.current_job then
         return
       end
+
+      if err or vim.startswith(line, '{"error"') then
+        handle_error('Failed to get response: ' .. (err and vim.inspect(err) or line))
+        return
+      end
+
+      line = line:gsub('^%s*data: ', '')
+      if line == '' or line == '[DONE]' then
+        return
+      end

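The reworked `stream_func` above now only strips the SSE framing and accumulates deltas, leaving finalization to `callback_func`. A minimal standalone sketch of that line handling, with fabricated input and the JSON decoding stubbed out by `print`:

    local sse_lines = { 'data: {"choices":[{"delta":{"content":"Hi"}}]}', '', 'data: [DONE]' }
    for _, line in ipairs(sse_lines) do
      line = line:gsub('^%s*data: ', '')
      if line ~= '' and line ~= '[DONE]' then
        print('would decode: ' .. line) -- the plugin passes this to vim.json.decode
      end
    end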
@@ -595,8 +645,7 @@
       })

       if not ok then
-        err = 'Failed to parse response: ' .. vim.inspect(content) .. '\n' .. line
-        log.error(err)
+        handle_error('Failed to parse response: ' .. vim.inspect(content) .. '\n' .. line)
         return
       end

@@ -613,27 +662,10 @@
         return
       end

-      if self.current_job and on_progress then
+      if on_progress then
         on_progress(content)
       end

-      if is_full then
-        log.trace('Full response: ' .. content)
-        log.debug('Last message: ' .. vim.inspect(last_message))
-        self.token_count = self.token_count + tiktoken.count(content)
-
-        if self.current_job and on_done then
-          on_done(content, self.token_count + current_count)
-        end
-
-        table.insert(self.history, {
-          content = content,
-          role = 'assistant',
-        })
-        return
-      end
-
+      -- Collect full response incrementally so we can insert it to history later
       full_response = full_response .. content
     end

@@ -646,6 +678,7 @@
       proxy = self.proxy,
       insecure = self.allow_insecure,
       stream = stream_func,
+      callback = callback_func,
       on_error = function(err)
         err = 'Failed to get response: ' .. vim.inspect(err)
         log.error(err)
Expand Down Expand Up @@ -794,7 +827,6 @@ end
function Copilot:reset()
local stopped = self:stop()
self.history = {}
self.token_count = 0
return stopped
end

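Taken together, the copilot.lua changes implement a simple token budget: required tokens (prompt, system prompt, selection) are fixed, history is truncated from the oldest message until it fits, and embedded files consume whatever remains. A minimal standalone sketch of the history-truncation step, runnable with plain lua — a toy word counter stands in for `tiktoken.count` and all numbers are illustrative:

    local function count(text) -- stand-in tokenizer: counts words, not real tokens
      local n = 0
      for _ in text:gmatch('%S+') do
        n = n + 1
      end
      return n
    end

    local max_tokens = 20
    local required_tokens = count('current prompt') + count('system prompt') + count('selected code')
    local history = {
      { role = 'user', content = 'first question about some module' },
      { role = 'assistant', content = 'a fairly long first answer with many words in it' },
      { role = 'user', content = 'second question' },
    }

    local history_tokens = 0
    for _, msg in ipairs(history) do
      history_tokens = history_tokens + count(msg.content)
    end

    -- Drop oldest messages until history fits the leftover budget.
    local history_limit = max_tokens - required_tokens
    while history_tokens > history_limit and #history > 0 do
      local removed = table.remove(history, 1)
      history_tokens = history_tokens - count(removed.content)
    end

    print(('kept %d messages, %d/%d tokens used'):format(#history, required_tokens + history_tokens, max_tokens))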
2 changes: 1 addition & 1 deletion lua/CopilotChat/health.lua
@@ -99,7 +99,7 @@ function M.check()
     ok('tiktoken_core: installed')
   else
     warn(
-      'tiktoken_core: missing, optional for token counting. See README for installation instructions.'
+      'tiktoken_core: missing, optional for accurate token counting. See README for installation instructions.'
     )
   end

6 changes: 3 additions & 3 deletions lua/CopilotChat/init.lua
@@ -440,12 +440,12 @@ function M.ask(prompt, config, source)
     model = config.model,
     temperature = config.temperature,
     on_error = on_error,
-    on_done = function(response, token_count)
+    on_done = function(response, token_count, token_max_count)
       vim.schedule(function()
         append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
         state.response = response
-        if token_count and token_count > 0 then
-          state.chat:finish(token_count .. ' tokens used')
+        if token_count and token_max_count and token_count > 0 then
+          state.chat:finish(token_count .. '/' .. token_max_count .. ' tokens used')
         else
           state.chat:finish()
         end
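With the extra callback argument, the chat footer can show real usage from the API response against the model's context window. A hypothetical run of the branch above (values are made up; `usage.total_tokens` comes from GitHub's response):

    local token_count, token_max_count = 1874, 8192
    if token_count and token_max_count and token_count > 0 then
      print(token_count .. '/' .. token_max_count .. ' tokens used') -- 1874/8192 tokens used
    end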
