Commit 559a44c
Parent: 08d31da

Support tool use, add example

3 files changed: 81 additions, 17 deletions

Applications/LLMEval/ContentView.swift

Lines changed: 42 additions & 8 deletions
@@ -10,10 +10,10 @@ import SwiftUI
 import Tokenizers
 
 struct ContentView: View {
+    @Environment(DeviceStat.self) private var deviceStat
 
-    @State var prompt = ""
     @State var llm = LLMEvaluator()
-    @Environment(DeviceStat.self) private var deviceStat
+    @State var prompt = "What's the current weather in Paris?"
 
     enum displayStyle: String, CaseIterable, Identifiable {
         case plain, markdown
@@ -34,6 +34,9 @@ struct ContentView: View {
                     Text(llm.stat)
                 }
                 HStack {
+                    Toggle(isOn: $llm.includeWeatherTool) {
+                        Text("Include \"get current weather\" tool")
+                    }
                     Spacer()
                     if llm.running {
                         ProgressView()
@@ -126,8 +129,6 @@ struct ContentView: View {
 
         }
        .task {
-            self.prompt = llm.modelConfiguration.defaultPrompt
-
            // pre-load the weights on launch to speed up the first generation
            _ = try? await llm.load()
        }
@@ -154,13 +155,19 @@ class LLMEvaluator {
 
     var running = false
 
+    var includeWeatherTool = false
+
     var output = ""
     var modelInfo = ""
     var stat = ""
 
     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelRegistry.phi3_5_4bit
+    // let modelConfiguration = ModelRegistry.phi3_5_4bit
+    // let modelConfiguration = ModelRegistry.llama3_2_3B_4bit
+    // let modelConfiguration = ModelRegistry.llama3_1_8B_4bit
+    // let modelConfiguration = ModelRegistry.mistral7B4bit
+    let modelConfiguration = ModelRegistry.qwen2_5_7b
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -178,6 +185,29 @@ class LLMEvaluator {
 
     var loadState = LoadState.idle
 
+    let currentWeatherToolSpec: [String: any Sendable] =
+        [
+            "type": "function",
+            "function": [
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": [
+                    "type": "object",
+                    "properties": [
+                        "location": [
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        ] as [String: String],
+                        "unit": [
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        ] as [String: any Sendable],
+                    ] as [String: [String: any Sendable]],
+                    "required": ["location"],
+                ] as [String: any Sendable],
+            ] as [String: any Sendable],
+        ] as [String: any Sendable]
+
     /// load and return the model -- can be called multiple times, subsequent calls will
     /// just return the loaded model
     func load() async throws -> ModelContainer {
@@ -222,18 +252,22 @@ class LLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
         let result = try await modelContainer.perform { context in
-            let input = try await context.processor.prepare(input: .init(prompt: prompt))
+            let input = try await context.processor.prepare(
+                input: .init(
+                    messages: [
+                        ["role": "system", "content": "You are a helpful assistant."],
+                        ["role": "user", "content": prompt],
+                    ], tools: includeWeatherTool ? [currentWeatherToolSpec] : nil))
             return try MLXLMCommon.generate(
                 input: input, parameters: generateParameters, context: context
             ) { tokens in
-                // update the output -- this will make the view show the text as it generates
+                // Show the text in the view as it generates
                 if tokens.count % displayEveryNTokens == 0 {
                     let text = context.tokenizer.decode(tokens: tokens)
                     Task { @MainActor in
                         self.output = text
                     }
                 }
-
                 if tokens.count >= maxTokens {
                     return .stop
                 } else {
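
With the toggle on, a tool-capable model such as Qwen2.5 typically replies with a tool call instead of prose. A hypothetical follow-up sketch, not part of this commit, for decoding that output; it assumes the Qwen convention of wrapping calls in <tool_call></tool_call> tags around a JSON object, and the ToolCall/parseToolCall names are illustrative:

import Foundation

// Hypothetical helper (not in this commit): the JSON payload a
// Qwen-style tool call carries inside the <tool_call> tags.
struct ToolCall: Decodable {
    let name: String
    let arguments: [String: String]
}

// Extract the first <tool_call> payload from the generated text, if any.
func parseToolCall(from output: String) -> ToolCall? {
    guard let start = output.range(of: "<tool_call>"),
        let end = output.range(of: "</tool_call>", range: start.upperBound ..< output.endIndex)
    else { return nil }
    let json = output[start.upperBound ..< end.lowerBound]
    return try? JSONDecoder().decode(ToolCall.self, from: Data(json.utf8))
}

A real app would execute the matching function and feed the result back as a "tool" role message for a second generation pass; this example stops at streaming the raw text.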

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 17 additions & 6 deletions
@@ -143,10 +143,14 @@ public class ModelRegistry: @unchecked Sendable {
         defaultPrompt: "What is the difference between lettuce and cabbage?"
     )
 
-    static public let qwen205b4bit = ModelConfiguration(
-        id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
-        overrideTokenizer: "PreTrainedTokenizer",
-        defaultPrompt: "why is the sky blue?"
+    static public let qwen2_5_7b = ModelConfiguration(
+        id: "mlx-community/Qwen2.5-7B-Instruct-4bit",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
+    static public let qwen2_5_1_5b = ModelConfiguration(
+        id: "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
+        defaultPrompt: "Why is the sky blue?"
     )
 
     static public let openelm270m4bit = ModelConfiguration(
@@ -192,7 +196,8 @@ public class ModelRegistry: @unchecked Sendable {
         phi3_5MoE,
         phi3_5_4bit,
         phi4bit,
-        qwen205b4bit,
+        qwen2_5_7b,
+        qwen2_5_1_5b,
         smolLM_135M_4bit,
     ]
 }
@@ -229,7 +234,13 @@ private struct LLMUserInputProcessor: UserInputProcessor {
     func prepare(input: UserInput) throws -> LMInput {
         do {
             let messages = input.prompt.asMessages()
-            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+            let promptTokens = try tokenizer.applyChatTemplate(
+                messages: messages, tools: input.tools, additionalContext: input.additionalContext)
+
+            let promptDecoded = try tokenizer.decode(tokens: promptTokens)
+
+            print(promptDecoded)
+
             return LMInput(tokens: MLXArray(promptTokens))
         } catch {
             // #150 -- it might be a TokenizerError.chatTemplate("No chat template was specified")
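
The decode-and-print step is a debugging aid: it shows what the chat template did with the messages and tools. For Qwen2.5 with the weather tool enabled, the printed prompt looks roughly like the following. This is abridged and illustrative only; the exact wording comes from the chat template bundled with the model, not from this repo:

<|im_start|>system
You are a helpful assistant.

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "get_current_weather", ...}}
</tools>

For each function call, return a JSON object with the function name and arguments within <tool_call></tool_call> XML tags ...
<|im_end|>
<|im_start|>user
What's the current weather in Paris?<|im_end|>
<|im_start|>assistant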

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 22 additions & 3 deletions
@@ -3,6 +3,7 @@
 import CoreImage
 import Foundation
 import MLX
+import Tokenizers
 
 /// Container for raw user input.
 ///
@@ -108,23 +109,41 @@ public struct UserInput: Sendable {
     }
 
     public var prompt: Prompt
+    public var tools: [ToolSpec]?
+    /// Additional values provided for the chat template rendering context
+    public var additionalContext: [String: Any]?
     public var images = [Image]()
     public var processing: Processing = .init()
 
-    public init(prompt: String, images: [Image] = [Image]()) {
+    public init(
+        prompt: String, images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+        additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = .text(prompt)
         self.images = images
+        self.tools = tools
+        self.additionalContext = additionalContext
     }
 
-    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+    public init(
+        messages: [[String: String]], images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+        additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = .messages(messages)
         self.images = images
+        self.tools = tools
+        self.additionalContext = additionalContext
     }
 
-    public init(prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init()) {
+    public init(
+        prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init(),
+        tools: [ToolSpec]? = nil, additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = prompt
         self.images = images
         self.processing = processing
+        self.tools = tools
+        self.additionalContext = additionalContext
     }
 }
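
The new parameters thread through all three initializers. A brief usage sketch (hypothetical caller code, mirroring the spec shape used in ContentView.swift above, with the weather tool abbreviated):

import MLXLMCommon
import Tokenizers

// Abbreviated OpenAI-style tool spec, same shape as
// currentWeatherToolSpec in ContentView.swift.
let weatherTool: [String: any Sendable] = [
    "type": "function",
    "function": [
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
    ] as [String: String],
]

// Tools ride along with the chat messages; processors that support
// them pass `tools` through to the tokenizer's chat template.
let input = UserInput(
    messages: [
        ["role": "system", "content": "You are a helpful assistant."],
        ["role": "user", "content": "What's the current weather in Paris?"],
    ],
    tools: [weatherTool])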
