import sys
import matplotlib.pyplot as plt
import mesa.experimental.components.matplotlib as components_matplotlib
import solara
import xyzservices.providers as xyz
from mesa.experimental import jupyter_viz as jv
from solara.alias import rv
import mesa_geo.visualization.leaflet_viz as leaflet_viz
# Avoid interactive backend
plt.switch_backend("agg")
# TODO: Turn this function into a Solara component once the current_step.value
# dependency is passed to measure()
"""
Geo-Mesa Visualization Module
=============================
<<<<<<< HEAD
Card: Helper Function that initiates the Solara Card for Browser
GeoJupyterViz: Main Function users employ to create visualization
=======
Card: Helper Function that initiates the Solara Card for Browser
GeoJupyterViz: Main Function users employ to create visualization
>>>>>>> main
"""
[docs]
def Card(
    model,
    measures,
    agent_portrayal,
    map_drawer,
    center_default,
    zoom,
    current_step,
    color,
    layout_type,
):
    """Build one dashboard card holding either the map or a single measure.

    Args:
        model: model instance handed to the map and measure renderers
        measures: sequence of measures; indexed by ``layout_type["Measure"]``
        agent_portrayal: accepted for signature compatibility; not read here
        map_drawer: map module forwarded to ``leaflet_viz.map``
        center_default: map center forwarded to ``leaflet_viz.map``
        zoom: zoom level forwarded to ``leaflet_viz.map``
        current_step: reactive step counter; its ``value`` is the redraw
            dependency of matplotlib-based measures
        color: CSS background color of the card
        layout_type: mapping with a ``"Map"`` key and/or a ``"Measure"`` key
            whose value is an index into ``measures``

    Returns:
        The ``rv.Card`` element wrapping the rendered content.
    """
    card_style = f"background-color: {color}; width: 100%; height: 100%"
    with rv.Card(style_=card_style) as main:
        wants_map = "Map" in layout_type
        if wants_map:
            rv.CardTitle(children=["Map"])
            leaflet_viz.map(model, map_drawer, zoom, center_default)

        wants_measure = "Measure" in layout_type
        if wants_measure:
            rv.CardTitle(children=["Measure"])
            selected = measures[layout_type["Measure"]]
            if not callable(selected):
                # Plain data attribute: plot it with the stock matplotlib component
                components_matplotlib.PlotMatplotlib(
                    model, selected, dependencies=[current_step.value]
                )
            else:
                # Custom component supplied by the user
                selected(model)
    return main
@solara.component
def GeoJupyterViz(
    model_class,
    model_params,
    measures=None,
    name=None,
    agent_portrayal=None,
    play_interval=150,
    # parameters for leaflet_viz
    view=None,
    zoom=None,
    scroll_wheel_zoom=True,
    tiles=xyz.OpenStreetMap.Mapnik,
    center_point=None,  # Due to projection challenges in calculation allow user to specify center point
):
    """Initialize a component to visualize a model.

    Args:
        model_class: class of the model to instantiate
        model_params: parameters for initializing the model; split into
            user-adjustable and fixed parameters by ``jv.split_model_params``
        measures: list of callables or data attributes to plot; callables are
            invoked with the model, anything else is plotted with
            ``PlotMatplotlib``. ``None`` (default) renders no measures.
        name: name for display; defaults to ``model_class.__name__``
        agent_portrayal: options for rendering agents, forwarded to
            ``leaflet_viz.MapModule`` as ``portrayal_method``
        play_interval: play interval (default: 150)
        view: forwarded to ``leaflet_viz.MapModule``
        zoom: zoom level, forwarded to ``leaflet_viz.MapModule`` and to the
            map components
        scroll_wheel_zoom: forwarded to ``leaflet_viz.MapModule`` and to the
            Jupyter map component (default: True)
        tiles: tile provider forwarded to ``leaflet_viz.MapModule``
            (default: ``xyz.OpenStreetMap.Mapnik``)
        center_point: list of center coords; when not given, the midpoint of
            the rendered layers' ``total_bounds`` is used
    """
    if name is None:
        name = model_class.__name__

    current_step = solara.use_reactive(0)

    # 1. Set up model parameters
    user_params, fixed_params = jv.split_model_params(model_params)
    model_parameters, set_model_parameters = solara.use_state(
        {**fixed_params, **{k: v.get("value") for k, v in user_params.items()}}
    )

    # 2. Set up Model
    def make_model():
        model = model_class(**model_parameters)
        # A freshly built model always starts at step zero
        current_step.value = 0
        return model

    reset_counter = solara.use_reactive(0)
    # Rebuild the model whenever a parameter value changes or a reset is requested
    model = solara.use_memo(
        make_model,
        dependencies=[*model_parameters.values(), reset_counter.value],
    )

    def handle_change_model_params(name: str, value: object):
        set_model_parameters({**model_parameters, name: value})

    # 3. Set up UI
    with solara.AppBar():
        solara.AppBarTitle(name)

    # 4. Set Up Map
    # render layout, pass through map build parameters
    map_drawer = leaflet_viz.MapModule(
        portrayal_method=agent_portrayal,
        view=view,
        zoom=zoom,
        tiles=tiles,
        scroll_wheel_zoom=scroll_wheel_zoom,
    )
    layers = map_drawer.render(model)

    # determine center point (shared by the Jupyter and browser renderers)
    if center_point:
        center_default = center_point
    else:
        # bounds is indexed as two corner points; take their midpoint
        bounds = layers["layers"]["total_bounds"]
        center_default = [
            (bounds[0][0] + bounds[1][0]) / 2,
            (bounds[0][1] + bounds[1][1]) / 2,
        ]

    def render_in_jupyter():
        # TODO: Build API to allow users to set rows and columns
        # call in property of model layers geospace line; use 1 column to prevent map overlap

        # NOTE(review): `Row(...) and GridFixed(...)` evaluates to the GridFixed
        # element (assuming the Row element is truthy), so only the grid is
        # entered as a context manager. Kept as-is to preserve the current
        # layout; confirm whether nesting (`with Row(...), GridFixed(...)`)
        # was intended.
        with solara.Row(
            justify="space-between", style={"flex-grow": "1"}
        ) and solara.GridFixed(columns=2):
            jv.UserInputs(user_params, on_change=handle_change_model_params)
            jv.ModelController(model, play_interval, current_step, reset_counter)
            solara.Markdown(md_text=f"###Step - {current_step}")

        # Builds Solara component of map
        leaflet_viz.map_jupyter(
            model, map_drawer, zoom, center_default, scroll_wheel_zoom
        )

        # Place measurement in separate row
        with solara.Row(
            justify="space-between",
            style={"flex-grow": "1"},
        ):
            # 5. Plots
            # `measures` defaults to None; treat that as "nothing to plot"
            for measure in measures or []:
                if callable(measure):
                    # Is a custom object
                    measure(model)
                else:
                    components_matplotlib.PlotMatplotlib(
                        model, measure, dependencies=[current_step.value]
                    )

    def render_in_browser():
        # The map card is always present; one extra card is added per measure
        layout_types = [{"Map": "default"}]

        if measures:
            layout_types += [{"Measure": elem} for elem in range(len(measures))]

        grid_layout_initial = jv.make_initial_grid_layout(layout_types=layout_types)
        grid_layout, set_grid_layout = solara.use_state(grid_layout_initial)

        with solara.Sidebar():
            with solara.Card("Controls", margin=1, elevation=2):
                jv.UserInputs(user_params, on_change=handle_change_model_params)
                jv.ModelController(model, play_interval, current_step, reset_counter)
            with solara.Card("Progress", margin=1, elevation=2):
                solara.Markdown(md_text=f"####Step - {current_step}")

        # center_default comes from the enclosing scope so that both renderers
        # agree on how the map center is derived from the layer bounds
        items = [
            Card(
                model,
                measures,
                agent_portrayal,
                map_drawer,
                center_default,
                zoom,
                current_step,
                color="white",
                layout_type=layout_type,
            )
            for layout_type in layout_types
        ]
        solara.GridDraggable(
            items=items,
            grid_layout=grid_layout,
            resizable=True,
            draggable=True,
            on_grid_layout=set_grid_layout,
        )

    if ("ipykernel" in sys.argv[0]) or ("colab_kernel_launcher.py" in sys.argv[0]):
        # When in Jupyter or Google Colab
        render_in_jupyter()
    else:
        render_in_browser()