@@ -10,10 +10,10 @@ import SwiftUI
 import Tokenizers

 struct ContentView: View {
+    @Environment(DeviceStat.self) private var deviceStat

-    @State var prompt = ""
     @State var llm = LLMEvaluator()
-    @Environment(DeviceStat.self) private var deviceStat
+    @State var prompt = "What's the current weather in Paris?"

     enum displayStyle: String, CaseIterable, Identifiable {
         case plain, markdown
@@ -34,6 +34,9 @@ struct ContentView: View {
                     Text(llm.stat)
                 }
                 HStack {
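+                    // Let the user choose whether the weather tool is offered to the model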
+                    Toggle(isOn: $llm.includeWeatherTool) {
+                        Text("Include \"get current weather\" tool")
+                    }
                     Spacer()
                     if llm.running {
                         ProgressView()
@@ -126,8 +129,6 @@ struct ContentView: View {

         }
         .task {
-            self.prompt = llm.modelConfiguration.defaultPrompt
-
             // pre-load the weights on launch to speed up the first generation
             _ = try? await llm.load()
         }
@@ -154,13 +155,19 @@ class LLMEvaluator {

     var running = false

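+    // Set from the UI toggle; checked below before attaching the tool spec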
+    var includeWeatherTool = false
+
     var output = ""
     var modelInfo = ""
     var stat = ""

     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelRegistry.phi3_5_4bit
+    // let modelConfiguration = ModelRegistry.phi3_5_4bit
+    // let modelConfiguration = ModelRegistry.llama3_2_3B_4bit
+    // let modelConfiguration = ModelRegistry.llama3_1_8B_4bit
+    // let modelConfiguration = ModelRegistry.mistral7B4bit
+    let modelConfiguration = ModelRegistry.qwen2_5_7b
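+    // presumably chosen because Qwen 2.5's chat template supports tool calling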

     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -178,6 +185,29 @@ class LLMEvaluator {

     var loadState = LoadState.idle

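+    // A tool description in the OpenAI function-calling JSON schema shape:
+    // one `get_current_weather` function with a required `location` and an
+    // optional `unit` ("celsius" or "fahrenheit")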
+    let currentWeatherToolSpec: [String: any Sendable] =
+        [
+            "type": "function",
+            "function": [
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": [
+                    "type": "object",
+                    "properties": [
+                        "location": [
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        ] as [String: String],
+                        "unit": [
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        ] as [String: any Sendable],
+                    ] as [String: [String: any Sendable]],
+                    "required": ["location"],
+                ] as [String: any Sendable],
+            ] as [String: any Sendable],
+        ] as [String: any Sendable]
+
     /// load and return the model -- can be called multiple times, subsequent calls will
     /// just return the loaded model
     func load() async throws -> ModelContainer {
@@ -222,18 +252,22 @@ class LLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

         let result = try await modelContainer.perform { context in
-            let input = try await context.processor.prepare(input: .init(prompt: prompt))
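+            // Build a chat-style input; the tool spec is attached only when the toggle is on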
+            let input = try await context.processor.prepare(
+                input: .init(
+                    messages: [
+                        ["role": "system", "content": "You are a helpful assistant."],
+                        ["role": "user", "content": prompt],
+                    ], tools: includeWeatherTool ? [currentWeatherToolSpec] : nil))

             return try MLXLMCommon.generate(
                 input: input, parameters: generateParameters, context: context
             ) { tokens in
-                // update the output -- this will make the view show the text as it generates
+                // Show the text in the view as it generates
                 if tokens.count % displayEveryNTokens == 0 {
                     let text = context.tokenizer.decode(tokens: tokens)
                     Task { @MainActor in
                         self.output = text
                     }
                 }
-
                 if tokens.count >= maxTokens {
                     return .stop
                 } else {