mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-08 11:50:31 -05:00
feat(function): Add tool streaming, XML Tool Call Parsing Support (#7865)
* feat(function): Add XML Tool Call Parsing Support Extend the function parsing system in LocalAI to support XML-style tool calls, similar to how JSON tool calls are currently parsed. This will allow models that return XML format (like <tool_call><function=name><parameter=key>value</parameter></function></tool_call>) to be properly parsed alongside text content. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * thinking before tool calls, more strict support for corner cases with no tools Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Support streaming tools Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Iterative JSON Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Iterative parsing Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Consume JSON marker Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Fixup Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * add tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Fix pending TODOs Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Don't run other parsing with ParseRegex Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
9d3da0bed5
commit
21c84f432f
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,431 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// JSONStackElementType identifies which kind of JSON construct an entry on the
// parsing stack stands for.
type JSONStackElementType int

// Kinds of stack entries, numbered from zero in declaration order.
const (
	// JSONStackElementObject is the entry pushed when a '{' is scanned.
	JSONStackElementObject JSONStackElementType = iota
	// JSONStackElementKey is the entry pushed for an object key.
	JSONStackElementKey
	// JSONStackElementArray is the entry pushed when a '[' is scanned.
	JSONStackElementArray
)

// JSONStackElement is a single entry of the JSON parsing stack.
type JSONStackElement struct {
	// Type says whether the entry is an object, a key or an array.
	Type JSONStackElementType
	// Key holds the key text for JSONStackElementKey entries; the scanner
	// leaves it empty for objects and arrays.
	Key string
}

// JSONErrorLocator records where JSON decoding stopped and what was still
// open at that point.
type JSONErrorLocator struct {
	// position is the byte offset at which decoding stopped.
	position int
	// foundError reports whether decoding the input failed.
	foundError bool
	// lastToken is not written by the parsing code in this file.
	// NOTE(review): presumably kept for parity with llama.cpp's locator - confirm.
	// exceptionMessage is the text of the decoding error.
	lastToken, exceptionMessage string
	// stack lists the constructs still open, outermost first.
	stack []JSONStackElement
}
|
||||
|
||||
// parseJSONWithStack parses JSON with stack tracking, matching llama.cpp's common_json_parse
|
||||
// Returns the parsed JSON value, whether it was healed, and any error
|
||||
func parseJSONWithStack(input string, healingMarker string) (any, bool, string, error) {
|
||||
if healingMarker == "" {
|
||||
// No healing marker, just try to parse normally
|
||||
var result any
|
||||
if err := json.Unmarshal([]byte(input), &result); err != nil {
|
||||
return nil, false, "", err
|
||||
}
|
||||
return result, false, "", nil
|
||||
}
|
||||
|
||||
// Try to parse complete JSON first
|
||||
var result any
|
||||
if err := json.Unmarshal([]byte(input), &result); err == nil {
|
||||
return result, false, "", nil
|
||||
}
|
||||
|
||||
// Parsing failed, need to track stack and heal
|
||||
errLoc := &JSONErrorLocator{
|
||||
position: 0,
|
||||
foundError: false,
|
||||
stack: make([]JSONStackElement, 0),
|
||||
}
|
||||
|
||||
// Parse with stack tracking to find where error occurs
|
||||
errorPos, err := parseJSONWithStackTracking(input, errLoc)
|
||||
if err == nil && !errLoc.foundError {
|
||||
// No error found, should have parsed successfully
|
||||
var result any
|
||||
if err := json.Unmarshal([]byte(input), &result); err != nil {
|
||||
return nil, false, "", err
|
||||
}
|
||||
return result, false, "", nil
|
||||
}
|
||||
|
||||
if !errLoc.foundError || len(errLoc.stack) == 0 {
|
||||
// Can't heal without stack information
|
||||
return nil, false, "", errors.New("incomplete JSON")
|
||||
}
|
||||
|
||||
// Build closing braces/brackets from stack
|
||||
closing := ""
|
||||
for i := len(errLoc.stack) - 1; i >= 0; i-- {
|
||||
el := errLoc.stack[i]
|
||||
if el.Type == JSONStackElementObject {
|
||||
closing += "}"
|
||||
} else if el.Type == JSONStackElementArray {
|
||||
closing += "]"
|
||||
}
|
||||
// Keys don't add closing characters
|
||||
}
|
||||
|
||||
// Get the partial input up to error position
|
||||
partialInput := input
|
||||
if errorPos > 0 && errorPos < len(input) {
|
||||
partialInput = input[:errorPos]
|
||||
}
|
||||
|
||||
// Find last non-space character
|
||||
lastNonSpacePos := strings.LastIndexFunc(partialInput, func(r rune) bool {
|
||||
return !unicode.IsSpace(r)
|
||||
})
|
||||
if lastNonSpacePos == -1 {
|
||||
return nil, false, "", errors.New("cannot heal a truncated JSON that stopped in an unknown location")
|
||||
}
|
||||
lastNonSpaceChar := rune(partialInput[lastNonSpacePos])
|
||||
|
||||
// Check if we stopped on a number
|
||||
wasMaybeNumber := func() bool {
|
||||
if len(partialInput) > 0 && unicode.IsSpace(rune(partialInput[len(partialInput)-1])) {
|
||||
return false
|
||||
}
|
||||
return unicode.IsDigit(lastNonSpaceChar) ||
|
||||
lastNonSpaceChar == '.' ||
|
||||
lastNonSpaceChar == 'e' ||
|
||||
lastNonSpaceChar == 'E' ||
|
||||
lastNonSpaceChar == '-'
|
||||
}
|
||||
|
||||
// Check for partial unicode escape sequences
|
||||
partialUnicodeRegex := regexp.MustCompile(`\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$`)
|
||||
unicodeMarkerPadding := "udc00"
|
||||
lastUnicodeMatch := partialUnicodeRegex.FindStringSubmatch(partialInput)
|
||||
if lastUnicodeMatch != nil {
|
||||
// Pad the escape sequence
|
||||
unicodeMarkerPadding = strings.Repeat("0", 6-len(lastUnicodeMatch[0]))
|
||||
// Check if it's a high surrogate
|
||||
if len(lastUnicodeMatch[0]) >= 4 {
|
||||
seq := lastUnicodeMatch[0]
|
||||
if seq[0] == '\\' && seq[1] == 'u' {
|
||||
third := strings.ToLower(string(seq[2]))
|
||||
if third == "d" {
|
||||
fourth := strings.ToLower(string(seq[3]))
|
||||
if fourth == "8" || fourth == "9" || fourth == "a" || fourth == "b" {
|
||||
// High surrogate, add low surrogate
|
||||
unicodeMarkerPadding += "\\udc00"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
canParse := func(str string) bool {
|
||||
var test any
|
||||
return json.Unmarshal([]byte(str), &test) == nil
|
||||
}
|
||||
|
||||
// Heal based on stack top element type
|
||||
healedJSON := partialInput
|
||||
jsonDumpMarker := ""
|
||||
topElement := errLoc.stack[len(errLoc.stack)-1]
|
||||
|
||||
if topElement.Type == JSONStackElementKey {
|
||||
// We're inside an object value
|
||||
if lastNonSpaceChar == ':' && canParse(healedJSON+"1"+closing) {
|
||||
jsonDumpMarker = "\"" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if canParse(healedJSON + ": 1" + closing) {
|
||||
jsonDumpMarker = ":\"" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if lastNonSpaceChar == '{' && canParse(healedJSON+closing) {
|
||||
jsonDumpMarker = "\"" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\": 1" + closing
|
||||
} else if canParse(healedJSON + "\"" + closing) {
|
||||
jsonDumpMarker = healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if len(healedJSON) > 0 && healedJSON[len(healedJSON)-1] == '\\' && canParse(healedJSON+"\\\""+closing) {
|
||||
jsonDumpMarker = "\\" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if canParse(healedJSON + unicodeMarkerPadding + "\"" + closing) {
|
||||
jsonDumpMarker = unicodeMarkerPadding + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else {
|
||||
// Find last colon and cut back
|
||||
lastColon := strings.LastIndex(healedJSON, ":")
|
||||
if lastColon == -1 {
|
||||
return nil, false, "", errors.New("cannot heal a truncated JSON that stopped in an unknown location")
|
||||
}
|
||||
jsonDumpMarker = "\"" + healingMarker
|
||||
healedJSON = healedJSON[:lastColon+1] + jsonDumpMarker + "\"" + closing
|
||||
}
|
||||
} else if topElement.Type == JSONStackElementArray {
|
||||
// We're inside an array
|
||||
if (lastNonSpaceChar == ',' || lastNonSpaceChar == '[') && canParse(healedJSON+"1"+closing) {
|
||||
jsonDumpMarker = "\"" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if canParse(healedJSON + "\"" + closing) {
|
||||
jsonDumpMarker = healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if len(healedJSON) > 0 && healedJSON[len(healedJSON)-1] == '\\' && canParse(healedJSON+"\\\""+closing) {
|
||||
jsonDumpMarker = "\\" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if canParse(healedJSON + unicodeMarkerPadding + "\"" + closing) {
|
||||
jsonDumpMarker = unicodeMarkerPadding + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else if !wasMaybeNumber() && canParse(healedJSON+", 1"+closing) {
|
||||
jsonDumpMarker = ",\"" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\"" + closing
|
||||
} else {
|
||||
lastBracketOrComma := strings.LastIndexAny(healedJSON, "[,")
|
||||
if lastBracketOrComma == -1 {
|
||||
return nil, false, "", errors.New("cannot heal a truncated JSON array stopped in an unknown location")
|
||||
}
|
||||
jsonDumpMarker = "\"" + healingMarker
|
||||
healedJSON = healedJSON[:lastBracketOrComma+1] + jsonDumpMarker + "\"" + closing
|
||||
}
|
||||
} else if topElement.Type == JSONStackElementObject {
|
||||
// We're inside an object (expecting a key)
|
||||
if (lastNonSpaceChar == '{' && canParse(healedJSON+closing)) ||
|
||||
(lastNonSpaceChar == ',' && canParse(healedJSON+"\"\": 1"+closing)) {
|
||||
jsonDumpMarker = "\"" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\": 1" + closing
|
||||
} else if !wasMaybeNumber() && canParse(healedJSON+",\"\": 1"+closing) {
|
||||
jsonDumpMarker = ",\"" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\": 1" + closing
|
||||
} else if canParse(healedJSON + "\": 1" + closing) {
|
||||
jsonDumpMarker = healingMarker
|
||||
healedJSON += jsonDumpMarker + "\": 1" + closing
|
||||
} else if len(healedJSON) > 0 && healedJSON[len(healedJSON)-1] == '\\' && canParse(healedJSON+"\\\": 1"+closing) {
|
||||
jsonDumpMarker = "\\" + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\": 1" + closing
|
||||
} else if canParse(healedJSON + unicodeMarkerPadding + "\": 1" + closing) {
|
||||
jsonDumpMarker = unicodeMarkerPadding + healingMarker
|
||||
healedJSON += jsonDumpMarker + "\": 1" + closing
|
||||
} else {
|
||||
lastColon := strings.LastIndex(healedJSON, ":")
|
||||
if lastColon == -1 {
|
||||
return nil, false, "", errors.New("cannot heal a truncated JSON object stopped in an unknown location")
|
||||
}
|
||||
jsonDumpMarker = "\"" + healingMarker
|
||||
healedJSON = healedJSON[:lastColon+1] + jsonDumpMarker + "\"" + closing
|
||||
}
|
||||
} else {
|
||||
return nil, false, "", errors.New("cannot heal a truncated JSON object stopped in an unknown location")
|
||||
}
|
||||
|
||||
// Try to parse the healed JSON
|
||||
var healedValue any
|
||||
if err := json.Unmarshal([]byte(healedJSON), &healedValue); err != nil {
|
||||
return nil, false, "", err
|
||||
}
|
||||
|
||||
// Remove healing marker from result
|
||||
cleaned := removeHealingMarkerFromJSONAny(healedValue, healingMarker)
|
||||
return cleaned, true, jsonDumpMarker, nil
|
||||
}
|
||||
|
||||
// parseJSONWithStackTracking parses JSON while tracking the stack structure
|
||||
// Returns the error position and any error encountered
|
||||
// This implements stack tracking similar to llama.cpp's json_error_locator
|
||||
func parseJSONWithStackTracking(input string, errLoc *JSONErrorLocator) (int, error) {
|
||||
// First, try to parse to get exact error position
|
||||
decoder := json.NewDecoder(strings.NewReader(input))
|
||||
var test any
|
||||
err := decoder.Decode(&test)
|
||||
if err != nil {
|
||||
errLoc.foundError = true
|
||||
errLoc.exceptionMessage = err.Error()
|
||||
|
||||
var errorPos int
|
||||
if syntaxErr, ok := err.(*json.SyntaxError); ok {
|
||||
errorPos = int(syntaxErr.Offset)
|
||||
errLoc.position = errorPos
|
||||
} else {
|
||||
// Fallback: use end of input
|
||||
errorPos = len(input)
|
||||
errLoc.position = errorPos
|
||||
}
|
||||
|
||||
// Now build the stack by parsing up to the error position
|
||||
// This matches llama.cpp's approach of tracking stack during SAX parsing
|
||||
partialInput := input
|
||||
if errorPos > 0 && errorPos < len(input) {
|
||||
partialInput = input[:errorPos]
|
||||
}
|
||||
|
||||
// Track stack by parsing character by character up to error
|
||||
pos := 0
|
||||
inString := false
|
||||
escape := false
|
||||
keyStart := -1
|
||||
keyEnd := -1
|
||||
|
||||
for pos < len(partialInput) {
|
||||
ch := partialInput[pos]
|
||||
|
||||
if escape {
|
||||
escape = false
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\\' {
|
||||
escape = true
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '"' {
|
||||
if !inString {
|
||||
// Starting a string
|
||||
inString = true
|
||||
// Check if we're in an object context (expecting a key)
|
||||
if len(errLoc.stack) > 0 {
|
||||
top := errLoc.stack[len(errLoc.stack)-1]
|
||||
if top.Type == JSONStackElementObject {
|
||||
// This could be a key
|
||||
keyStart = pos + 1 // Start after quote
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Ending a string
|
||||
inString = false
|
||||
if keyStart != -1 {
|
||||
// This was potentially a key, extract it
|
||||
keyEnd = pos
|
||||
key := partialInput[keyStart:keyEnd]
|
||||
|
||||
// Look ahead to see if next non-whitespace is ':'
|
||||
nextPos := pos + 1
|
||||
for nextPos < len(partialInput) && unicode.IsSpace(rune(partialInput[nextPos])) {
|
||||
nextPos++
|
||||
}
|
||||
if nextPos < len(partialInput) && partialInput[nextPos] == ':' {
|
||||
// This is a key, add it to stack
|
||||
errLoc.stack = append(errLoc.stack, JSONStackElement{Type: JSONStackElementKey, Key: key})
|
||||
}
|
||||
keyStart = -1
|
||||
keyEnd = -1
|
||||
}
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle stack operations (outside strings)
|
||||
if ch == '{' {
|
||||
errLoc.stack = append(errLoc.stack, JSONStackElement{Type: JSONStackElementObject})
|
||||
} else if ch == '}' {
|
||||
// Pop object and any key on top (keys are popped when value starts, but handle here too)
|
||||
for len(errLoc.stack) > 0 {
|
||||
top := errLoc.stack[len(errLoc.stack)-1]
|
||||
errLoc.stack = errLoc.stack[:len(errLoc.stack)-1]
|
||||
if top.Type == JSONStackElementObject {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if ch == '[' {
|
||||
errLoc.stack = append(errLoc.stack, JSONStackElement{Type: JSONStackElementArray})
|
||||
} else if ch == ']' {
|
||||
// Pop array
|
||||
for len(errLoc.stack) > 0 {
|
||||
top := errLoc.stack[len(errLoc.stack)-1]
|
||||
errLoc.stack = errLoc.stack[:len(errLoc.stack)-1]
|
||||
if top.Type == JSONStackElementArray {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if ch == ':' {
|
||||
// Colon means we're starting a value, pop the key if it's on stack
|
||||
if len(errLoc.stack) > 0 && errLoc.stack[len(errLoc.stack)-1].Type == JSONStackElementKey {
|
||||
errLoc.stack = errLoc.stack[:len(errLoc.stack)-1]
|
||||
}
|
||||
}
|
||||
// Note: commas and whitespace don't affect stack structure
|
||||
|
||||
pos++
|
||||
}
|
||||
|
||||
return errorPos, err
|
||||
}
|
||||
|
||||
// No error, parse was successful - build stack anyway for completeness
|
||||
// (though we shouldn't need healing in this case)
|
||||
pos := 0
|
||||
inString := false
|
||||
escape := false
|
||||
|
||||
for pos < len(input) {
|
||||
ch := input[pos]
|
||||
|
||||
if escape {
|
||||
escape = false
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\\' {
|
||||
escape = true
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '{' {
|
||||
errLoc.stack = append(errLoc.stack, JSONStackElement{Type: JSONStackElementObject})
|
||||
} else if ch == '}' {
|
||||
for len(errLoc.stack) > 0 {
|
||||
top := errLoc.stack[len(errLoc.stack)-1]
|
||||
errLoc.stack = errLoc.stack[:len(errLoc.stack)-1]
|
||||
if top.Type == JSONStackElementObject {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if ch == '[' {
|
||||
errLoc.stack = append(errLoc.stack, JSONStackElement{Type: JSONStackElementArray})
|
||||
} else if ch == ']' {
|
||||
for len(errLoc.stack) > 0 {
|
||||
top := errLoc.stack[len(errLoc.stack)-1]
|
||||
errLoc.stack = errLoc.stack[:len(errLoc.stack)-1]
|
||||
if top.Type == JSONStackElementArray {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pos++
|
||||
}
|
||||
|
||||
return len(input), nil
|
||||
}
|
||||
+1317
-7
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user