Commit 50c3529

Support tool use, add example

1 parent 08d31da commit 50c3529

File tree: 3 files changed, +78 -17 lines

  Applications/LLMEval/ContentView.swift
  Libraries/MLXLLM/LLMModelFactory.swift
  Libraries/MLXLMCommon/UserInput.swift

Applications/LLMEval/ContentView.swift

Lines changed: 44 additions & 8 deletions
@@ -10,10 +10,10 @@ import SwiftUI
 import Tokenizers

 struct ContentView: View {
+    @Environment(DeviceStat.self) private var deviceStat

-    @State var prompt = ""
     @State var llm = LLMEvaluator()
-    @Environment(DeviceStat.self) private var deviceStat
+    @State var prompt = "What's the current weather in Paris?"

     enum displayStyle: String, CaseIterable, Identifiable {
         case plain, markdown
@@ -34,6 +34,10 @@ struct ContentView: View {
                 Text(llm.stat)
             }
             HStack {
+                Toggle(isOn: $llm.includeWeatherTool) {
+                    Text("Include \"get current weather\" tool")
+                }
+                .frame(maxWidth: 350, alignment: .leading)
                 Spacer()
                 if llm.running {
                     ProgressView()
@@ -126,8 +130,6 @@ struct ContentView: View {

         }
         .task {
-            self.prompt = llm.modelConfiguration.defaultPrompt
-
             // pre-load the weights on launch to speed up the first generation
             _ = try? await llm.load()
         }
@@ -154,13 +156,20 @@ class LLMEvaluator {

     var running = false

+    var includeWeatherTool = false
+
     var output = ""
     var modelInfo = ""
     var stat = ""

     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelRegistry.phi3_5_4bit
+    // let modelConfiguration = ModelRegistry.phi3_5_4bit
+    // let modelConfiguration = ModelRegistry.llama3_2_3B_4bit
+    // let modelConfiguration = ModelRegistry.llama3_1_8B_4bit
+    // let modelConfiguration = ModelRegistry.mistral7B4bit
+    // let modelConfiguration = ModelRegistry.qwen2_5_7b
+    let modelConfiguration = ModelRegistry.qwen2_5_1_5b

     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
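
The commented-out alternatives all refer to entries in ModelRegistry (updated in LLMModelFactory.swift below). As a sketch, any MLX-converted chat model on the Hugging Face Hub can be named directly, using only the ModelConfiguration(id:defaultPrompt:) initializer this diff already relies on; the repo id below is illustrative:

    // Hypothetical: point the evaluator at an arbitrary Hub model by repo id.
    // Tool calling needs a chat template that accepts a `tools` argument,
    // as the Qwen 2.5 and Llama 3.x instruct templates do.
    let modelConfiguration = ModelConfiguration(
        id: "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
        defaultPrompt: "What's the current weather in Paris?"
    )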
@@ -178,6 +187,29 @@ class LLMEvaluator {

     var loadState = LoadState.idle

+    let currentWeatherToolSpec: [String: any Sendable] =
+        [
+            "type": "function",
+            "function": [
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": [
+                    "type": "object",
+                    "properties": [
+                        "location": [
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        ] as [String: String],
+                        "unit": [
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        ] as [String: any Sendable],
+                    ] as [String: [String: any Sendable]],
+                    "required": ["location"],
+                ] as [String: any Sendable],
+            ] as [String: any Sendable],
+        ] as [String: any Sendable]
+
     /// load and return the model -- can be called multiple times, subsequent calls will
     /// just return the loaded model
     func load() async throws -> ModelContainer {
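
The spec above is a plain nested dictionary following the OpenAI-style function schema that chat templates understand. As a minimal sketch of how a second tool would look (the get_current_time tool and its fields are hypothetical, not part of this commit):

    // Hypothetical second tool, following the same OpenAI-style function schema;
    // anything matching this shape can be passed in the `tools` array.
    let getTimeToolSpec: [String: any Sendable] =
        [
            "type": "function",
            "function": [
                "name": "get_current_time",
                "description": "Get the current time in a given time zone",
                "parameters": [
                    "type": "object",
                    "properties": [
                        "timezone": [
                            "type": "string",
                            "description": "An IANA time zone identifier, e.g. Europe/Paris",
                        ] as [String: String]
                    ] as [String: [String: String]],
                    "required": ["timezone"],
                ] as [String: any Sendable],
            ] as [String: any Sendable],
        ] as [String: any Sendable]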
@@ -222,18 +254,22 @@ class LLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

         let result = try await modelContainer.perform { context in
-            let input = try await context.processor.prepare(input: .init(prompt: prompt))
+            let input = try await context.processor.prepare(
+                input: .init(
+                    messages: [
+                        ["role": "system", "content": "You are a helpful assistant."],
+                        ["role": "user", "content": prompt],
+                    ], tools: includeWeatherTool ? [currentWeatherToolSpec] : nil))
             return try MLXLMCommon.generate(
                 input: input, parameters: generateParameters, context: context
             ) { tokens in
-                // update the output -- this will make the view show the text as it generates
+                // Show the text in the view as it generates
                 if tokens.count % displayEveryNTokens == 0 {
                     let text = context.tokenizer.decode(tokens: tokens)
                     Task { @MainActor in
                         self.output = text
                     }
                 }
-
                 if tokens.count >= maxTokens {
                     return .stop
                 } else {
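
When the toggle is on and the model decides to use the tool, the generated text contains a structured tool call rather than a prose answer; this example stops at displaying it. A hedged sketch of extracting such a call, assuming Qwen 2.5's <tool_call>...</tool_call> wrapper and string-valued arguments as in the weather tool above (other templates use different markers):

    import Foundation

    // Matches the JSON payload Qwen 2.5-style templates emit inside <tool_call> tags,
    // e.g. {"name": "get_current_weather", "arguments": {"location": "Paris, France"}}.
    struct ToolCall: Decodable {
        let name: String
        let arguments: [String: String]
    }

    // Returns the first tool call found in the generated text, or nil if the
    // model answered in prose or used a different wrapper format.
    func parseToolCall(from output: String) -> ToolCall? {
        guard let start = output.range(of: "<tool_call>"),
            let end = output.range(of: "</tool_call>"),
            start.upperBound <= end.lowerBound
        else { return nil }
        let json = output[start.upperBound ..< end.lowerBound]
        return try? JSONDecoder().decode(ToolCall.self, from: Data(json.utf8))
    }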

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 12 additions & 6 deletions
@@ -143,10 +143,14 @@ public class ModelRegistry: @unchecked Sendable {
         defaultPrompt: "What is the difference between lettuce and cabbage?"
     )

-    static public let qwen205b4bit = ModelConfiguration(
-        id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
-        overrideTokenizer: "PreTrainedTokenizer",
-        defaultPrompt: "why is the sky blue?"
+    static public let qwen2_5_7b = ModelConfiguration(
+        id: "mlx-community/Qwen2.5-7B-Instruct-4bit",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
+    static public let qwen2_5_1_5b = ModelConfiguration(
+        id: "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
+        defaultPrompt: "Why is the sky blue?"
     )

     static public let openelm270m4bit = ModelConfiguration(
@@ -192,7 +196,8 @@ public class ModelRegistry: @unchecked Sendable {
             phi3_5MoE,
             phi3_5_4bit,
             phi4bit,
-            qwen205b4bit,
+            qwen2_5_7b,
+            qwen2_5_1_5b,
             smolLM_135M_4bit,
         ]
     }
@@ -229,7 +234,8 @@ private struct LLMUserInputProcessor: UserInputProcessor {
     func prepare(input: UserInput) throws -> LMInput {
         do {
             let messages = input.prompt.asMessages()
-            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+            let promptTokens = try tokenizer.applyChatTemplate(
+                messages: messages, tools: input.tools, additionalContext: input.additionalContext)
             return LMInput(tokens: MLXArray(promptTokens))
         } catch {
             // #150 -- it might be a TokenizerError.chatTemplate("No chat template was specified")
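
This is the heart of the change: applyChatTemplate(messages:tools:additionalContext:) from swift-transformers' Tokenizers module renders the tool specs into the prompt via the model's own chat template (Qwen 2.5, for instance, embeds the tool JSON in the system region). A rough sketch of the equivalent direct call, assuming a loaded `tokenizer` and the `currentWeatherToolSpec` from the example app:

    // Render a tool-calling prompt directly; the chat template decides
    // where and how the tool JSON is embedded.
    let messages: [[String: String]] = [
        ["role": "system", "content": "You are a helpful assistant."],
        ["role": "user", "content": "What's the current weather in Paris?"],
    ]
    let promptTokens = try tokenizer.applyChatTemplate(
        messages: messages, tools: [currentWeatherToolSpec], additionalContext: nil)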

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 22 additions & 3 deletions
@@ -3,6 +3,7 @@
 import CoreImage
 import Foundation
 import MLX
+import Tokenizers

 /// Container for raw user input.
 ///
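
The new import Tokenizers is what brings ToolSpec into scope for the `tools` property below; in swift-transformers it is, as far as I can tell, just a dictionary alias:

    // From the Tokenizers module (swift-transformers), shown for reference:
    public typealias ToolSpec = [String: Any]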
@@ -108,23 +109,41 @@ public struct UserInput: Sendable {
     }

     public var prompt: Prompt
+    public var tools: [ToolSpec]?
+    /// Additional values provided for the chat template rendering context
+    public var additionalContext: [String: Any]?
     public var images = [Image]()
     public var processing: Processing = .init()

-    public init(prompt: String, images: [Image] = [Image]()) {
+    public init(
+        prompt: String, images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+        additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = .text(prompt)
         self.images = images
+        self.tools = tools
+        self.additionalContext = additionalContext
     }

-    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+    public init(
+        messages: [[String: String]], images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+        additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = .messages(messages)
         self.images = images
+        self.tools = tools
+        self.additionalContext = additionalContext
     }

-    public init(prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init()) {
+    public init(
+        prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init(),
+        tools: [ToolSpec]? = nil, additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = prompt
         self.images = images
         self.processing = processing
+        self.tools = tools
+        self.additionalContext = additionalContext
     }
 }
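
Taken together, the widened initializers let any caller opt into tool use; a minimal sketch using only API visible in this diff (the weather spec is the one defined in LLMEval's ContentView.swift above):

    // Build a chat-style input that advertises the weather tool to the model;
    // pass it to a UserInputProcessor's prepare(input:) as the example app does.
    let input = UserInput(
        messages: [
            ["role": "system", "content": "You are a helpful assistant."],
            ["role": "user", "content": "What's the current weather in Paris?"],
        ],
        tools: [currentWeatherToolSpec])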