Getting underlying ODEs in Jaxley #699
Replies: 4 comments 2 replies
-
Hi! Thanks for reaching out!
If you are trying to use other channel models, check this guide or our library of channel models. You can also implement other models (e.g. Izhikevich); this is described here. Hope this helps!
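To illustrate what implementing another model means at the level of the ODEs, here is a minimal plain-JAX sketch of the Izhikevich model with a forward-Euler step. It does not use Jaxley's channel API. The helper names `izhikevich_step` and `simulate`, the step size, and the parameter values (the common regular-spiking set a=0.02, b=0.2, c=-65, d=8) are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def izhikevich_step(v, u, i_ext, dt, a=0.02, b=0.2, c=-65.0, d=8.0):
    """One forward-Euler step of the Izhikevich model (hypothetical helper).

    dv/dt = 0.04 v^2 + 5 v + 140 - u + I
    du/dt = a (b v - u)
    followed by the reset v <- c, u <- u + d whenever v reaches 30 mV.
    """
    dv = 0.04 * v**2 + 5.0 * v + 140.0 - u + i_ext
    du = a * (b * v - u)
    v_new = v + dt * dv
    u_new = u + dt * du
    spiked = v_new >= 30.0
    # jnp.where keeps the reset free of Python control flow, so it also works under jit.
    v_new = jnp.where(spiked, c, v_new)
    u_new = jnp.where(spiked, u_new + d, u_new)
    return v_new, u_new, spiked


def simulate(i_ext, dt=0.5, n_steps=2000):
    """Run the model with a constant input current and return a boolean spike train."""

    def body(carry, _):
        v, u = carry
        v, u, spiked = izhikevich_step(v, u, i_ext, dt)
        return (v, u), spiked

    v0 = jnp.array(-65.0, dtype=jnp.float32)
    u0 = 0.2 * v0  # u starts at b * v0 for the default b
    _, spikes = jax.lax.scan(body, (v0, u0), xs=None, length=n_steps)
    return spikes


print(int(simulate(10.0).sum()), int(simulate(0.0).sum()))
```

With a constant input of 10 the model fires repeatedly, while with zero input it settles to its resting state; other choices of `a, b, c, d` give different firing patterns.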
-
Hi, we would be particularly interested in using the auto-differentiation capabilities of Jaxley to calculate the Jacobian of the underlying ODEs w.r.t. the state variables. The Jaxley paper [1] shows that this is possible for network states (Fig. S6), but we could not figure out how to do it. Thanks for the great work!

[1] /10.1101/2024.08.21.608979
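As a self-contained illustration of a Jacobian with respect to state variables, here is a sketch on a toy two-compartment passive model rather than a Jaxley cell. The states live in a dict (a pytree), `ravel_pytree` turns them into one vector, and `jax.jacfwd` differentiates the right-hand side. The model, its parameter values, and the names `rhs` and `rhs_flat` are assumptions for illustration. Note that this is the continuous-time Jacobian; the answer further down in this thread differentiates one solver step instead.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical toy parameters (not Jaxley defaults or units).
C_M = 1.0    # membrane capacitance
G_L = 0.1    # leak conductance
E_L = -70.0  # leak reversal potential
G_C = 0.5    # axial conductance coupling the two compartments


def rhs(states):
    """dV/dt for two passive compartments coupled by one axial conductance."""
    v = states["v"]
    # v[::-1] pairs each compartment with its single neighbour.
    coupling = G_C * (v[::-1] - v)
    return {"v": (-G_L * (v - E_L) + coupling) / C_M}


# Flatten the state pytree into a single vector.
states = {"v": jnp.array([-65.0, -60.0])}
flat, unravel_fn = ravel_pytree(states)


def rhs_flat(flat):
    return ravel_pytree(rhs(unravel_fn(flat)))[0]


jac = jax.jacfwd(rhs_flat)(flat)
# Analytically: [[-(G_L + G_C), G_C], [G_C, -(G_L + G_C)]] / C_M
print(jac)
```

Because this toy system is linear, its Jacobian does not depend on the current voltages; with voltage-gated channels it would.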
-
Hi! No, it is not possible to print the equations. @Matthijspals could you answer the question regarding the Jacobians?
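For orientation only, below is the standard form of the equations a multicompartment simulator integrates: the textbook compartmental cable equation with Hodgkin-Huxley-style gating. It is not printed by Jaxley, the symbols are generic, and Jaxley's exact units and discretisation may differ. For compartment $i$ with voltage $V_i$, channel currents $i_k$, gating variables $x$, neighbouring compartments $\mathcal{N}(i)$ and axial coupling conductances $g_{ij}$:

$$
C_m \frac{dV_i}{dt} = -\sum_k i_k(V_i, \mathbf{x}_i) + \sum_{j \in \mathcal{N}(i)} g_{ij}\,(V_j - V_i) + i_{\mathrm{ext},i},
\qquad
\frac{dx}{dt} = \alpha_x(V_i)\,(1 - x) - \beta_x(V_i)\,x .
$$

Inserting a different channel changes the $i_k$ terms and adds that channel's own state equations.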
-
Hi, regarding the Jacobians, the code for the supplementary figure is available. However, indexing is slightly different in the current version of Jaxley. I think the easiest way to get a Jacobian with the current version is to adapt the step function so that it takes in and returns a flat vector:

```python
def step_fn_vec_to_vec(flat):
    states = unravel_fn(flat)
    states = step_fn(states, params, {}, delta_t=delta_t)
    return ravel_pytree(states)[0]
```

Here is a full example, using an example cell with 10 branches, each with 3 compartments. Every compartment has a leak channel, and one branch has a sodium channel.

```python
import jaxley as jx
from jaxley.integrate import build_init_and_step_fn
from jaxley.channels import Na, Leak
from jax.flatten_util import ravel_pytree
from jax import jacfwd, jit
import numpy as np

# Simulation parameters
n_comps = 3
delta_t = 0.025
n_steps = 1

# Load the cell model and insert channels
cell = jx.read_swc("../jaxley/tests/swc_files/morph_ca1_n120_250_single_point_soma.swc", ncomp=n_comps)
# Visualise the cell
cell.vis()
cell.insert(Leak())
cell.branch(0).insert(Na())

# Set random initial voltages to get dynamics
for i in range(10):
    for j in range(n_comps):
        cell.branch(i).comp(j).set("v", -70 + 10 * np.random.rand())

cell.record()
params = cell.get_parameters()
cell.to_jax()
rec_inds = cell.recordings.rec_index.to_numpy()
rec_states = cell.recordings.state.to_numpy()

# Initialize.
init_fn, step_fn = build_init_and_step_fn(cell, solver="bwd_euler")
states, params = init_fn(params)

# For the Jacobian we need a step function that takes in a vector; we can use ravel_pytree for this.
# unravel_fn stays the same even if the state values change, so we can use it to reconstruct the states later.
flat, unravel_fn = ravel_pytree(states)
# Number of non-NaN entries in the flattened state vector
print(len(flat) - np.sum(np.isnan(flat)))

# Jit the vector-to-vector step function to improve speed.
@jit
def step_fn_vec_to_vec(flat):
    states = unravel_fn(flat)
    # Pass delta_t by keyword so it cannot be mistaken for another positional
    # argument of step_fn (e.g. external_inds in recent versions).
    states = step_fn(states, params, {}, delta_t=delta_t)
    return ravel_pytree(states)[0]

def clean_Jac(Jac, flat):
    """Remove rows/columns from Jac corresponding to flat entries with NaN values."""
    mask = ~np.isnan(flat)
    return Jac[mask][:, mask]

Jacs = []
for step in range(n_steps):
    # The next two lines can probably be done more efficiently together (jax.linearize?)
    flat = step_fn_vec_to_vec(flat)
    jac = jacfwd(step_fn_vec_to_vec)(flat)
    print(len(flat) - np.sum(np.isnan(flat)))
    # You can access the states like this
    states = unravel_fn(flat)
    # Store the Jacobians
    Jacs.append(clean_Jac(jac, flat))
```

Note that this includes entries for the branch points, as well as entries for the currents. You most likely want to remove the latter; I'm not 100% sure about the branch points (@michaeldeistler?). I think an additional
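The note about dropping the current entries can be handled, for example, by building a keep-mask with the same pytree structure as the states and flattening it with `ravel_pytree`, so its entries line up with the flattened state vector. The sketch below does this on a toy state dict rather than a real Jaxley cell, and also uses `jax.jacfwd(..., has_aux=True)` to get the next state and the Jacobian from a single call, one way to do the two steps "more efficiently together". The toy `step` function, the keys, and the assumption that current entries carry an `i_` prefix are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree

DT = 0.025  # hypothetical step size

# Hypothetical toy state pytree: a voltage, a gating variable (NaN-padded where
# the "channel" is absent), and a current entry that should not count as a state.
states = {
    "v": jnp.array([-70.0, -65.0]),
    "m": jnp.array([0.1, jnp.nan]),
    "i_Leak": jnp.array([0.0, 0.0]),
}
flat, unravel_fn = ravel_pytree(states)


def step(s):
    """Hypothetical forward-Euler step standing in for Jaxley's step_fn."""
    i_leak = 0.1 * (s["v"] + 70.0)
    v = s["v"] - DT * i_leak
    m = s["m"] + DT * (0.5 - s["m"])
    return {"v": v, "m": m, "i_Leak": i_leak}


def step_with_aux(flat):
    out = ravel_pytree(step(unravel_fn(flat)))[0]
    return out, out  # (value to differentiate, auxiliary copy of the same value)


# One call returns both the Jacobian and the next flattened state.
jac, flat_next = jax.jacfwd(step_with_aux, has_aux=True)(flat)

# Assumption: current entries are keyed with an "i_" prefix.
keep_tree = {k: jnp.full(x.shape, 0.0 if k.startswith("i_") else 1.0) for k, x in states.items()}
keep = np.asarray(ravel_pytree(keep_tree)[0]) > 0.5
keep &= ~np.isnan(np.asarray(flat_next))  # also drop NaN-padded entries

jac_clean = np.asarray(jac)[keep][:, keep]
print(jac_clean.shape)
```

Because `keep_tree` has the same keys and shapes as `states`, `ravel_pytree` flattens both in the same order, so the mask stays aligned without hard-coding any indices.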
-
Hi!
I've been running some simulations using Jaxley and I'm curious about the underlying differential equations it uses. Where would I be able to find them? Also, is it possible to modify them to simulate using different models?
Thanks!