@@ -10,10 +10,10 @@ import SwiftUI
10
10
import Tokenizers
11
11
12
12
struct ContentView : View {
13
+ @Environment ( DeviceStat . self) private var deviceStat
13
14
14
- @State var prompt = " "
15
15
@State var llm = LLMEvaluator ( )
16
- @Environment ( DeviceStat . self ) private var deviceStat
16
+ @State var prompt = " What's the current weather in Paris? "
17
17
18
18
enum displayStyle : String , CaseIterable , Identifiable {
19
19
case plain, markdown
@@ -34,6 +34,10 @@ struct ContentView: View {
34
34
Text ( llm. stat)
35
35
}
36
36
HStack {
37
+ Toggle ( isOn: $llm. includeWeatherTool) {
38
+ Text ( " Include \" get current weather \" tool " )
39
+ }
40
+ . frame ( maxWidth: 350 , alignment: . leading)
37
41
Spacer ( )
38
42
if llm. running {
39
43
ProgressView ( )
@@ -126,8 +130,6 @@ struct ContentView: View {
126
130
127
131
}
128
132
. task {
129
- self . prompt = llm. modelConfiguration. defaultPrompt
130
-
131
133
// pre-load the weights on launch to speed up the first generation
132
134
_ = try ? await llm. load ( )
133
135
}
@@ -154,13 +156,20 @@ class LLMEvaluator {
154
156
155
157
var running = false
156
158
159
+ var includeWeatherTool = false
160
+
157
161
var output = " "
158
162
var modelInfo = " "
159
163
var stat = " "
160
164
161
165
/// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
162
166
/// more devices.
163
- let modelConfiguration = ModelRegistry . phi3_5_4bit
167
+ // let modelConfiguration = ModelRegistry.phi3_5_4bit
168
+ // let modelConfiguration = ModelRegistry.llama3_2_3B_4bit
169
+ // let modelConfiguration = ModelRegistry.llama3_1_8B_4bit
170
+ // let modelConfiguration = ModelRegistry.mistral7B4bit
171
+ // let modelConfiguration = ModelRegistry.qwen2_5_7b
172
+ let modelConfiguration = ModelRegistry . qwen2_5_1_5b
164
173
165
174
/// parameters controlling the output
166
175
let generateParameters = GenerateParameters ( temperature: 0.6 )
@@ -178,6 +187,29 @@ class LLMEvaluator {
178
187
179
188
var loadState = LoadState . idle
180
189
190
+ let currentWeatherToolSpec : [ String : any Sendable ] =
191
+ [
192
+ " type " : " function " ,
193
+ " function " : [
194
+ " name " : " get_current_weather " ,
195
+ " description " : " Get the current weather in a given location " ,
196
+ " parameters " : [
197
+ " type " : " object " ,
198
+ " properties " : [
199
+ " location " : [
200
+ " type " : " string " ,
201
+ " description " : " The city and state, e.g. San Francisco, CA " ,
202
+ ] as [ String : String ] ,
203
+ " unit " : [
204
+ " type " : " string " ,
205
+ " enum " : [ " celsius " , " fahrenheit " ] ,
206
+ ] as [ String : any Sendable ] ,
207
+ ] as [ String : [ String : any Sendable ] ] ,
208
+ " required " : [ " location " ] ,
209
+ ] as [ String : any Sendable ] ,
210
+ ] as [ String : any Sendable ] ,
211
+ ] as [ String : any Sendable ]
212
+
181
213
/// load and return the model -- can be called multiple times, subsequent calls will
182
214
/// just return the loaded model
183
215
func load( ) async throws -> ModelContainer {
@@ -222,18 +254,22 @@ class LLMEvaluator {
222
254
MLXRandom . seed ( UInt64 ( Date . timeIntervalSinceReferenceDate * 1000 ) )
223
255
224
256
let result = try await modelContainer. perform { context in
225
- let input = try await context. processor. prepare ( input: . init( prompt: prompt) )
257
+ let input = try await context. processor. prepare (
258
+ input: . init(
259
+ messages: [
260
+ [ " role " : " system " , " content " : " You are a helpful assistant. " ] ,
261
+ [ " role " : " user " , " content " : prompt] ,
262
+ ] , tools: includeWeatherTool ? [ currentWeatherToolSpec] : nil ) )
226
263
return try MLXLMCommon . generate (
227
264
input: input, parameters: generateParameters, context: context
228
265
) { tokens in
229
- // update the output -- this will make the view show the text as it generates
266
+ // Show the text in the view as it generates
230
267
if tokens. count % displayEveryNTokens == 0 {
231
268
let text = context. tokenizer. decode ( tokens: tokens)
232
269
Task { @MainActor in
233
270
self . output = text
234
271
}
235
272
}
236
-
237
273
if tokens. count >= maxTokens {
238
274
return . stop
239
275
} else {
0 commit comments