Static image
This example applies retinomorphic convolution to a static image. It also shows how each layer is built and plots the cells of each layer together with their receptive fields.
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import skimage as ski
from pathlib import Path
from pyrception.utils import plot
from pyrception import logger
from pyrception.visual.rf import ReceptiveFields
from pyrception.visual.layers import (
    ReceptorLayer,
    HorizontalLayer,
    BipolarLayer,
    AmacrineLayer,
    GanglionLayer,
)
from pyrception.utils.enums import RFArrangement, KernelFilter, KernelShape
Open an image
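img_path = Path("./resources/cat_in_snow.jpg")
# Load the image and convert it to greyscale
img = ski.io.imread(img_path, as_gray=True)
# Downsample by a factor of 4 along each axis, with anti-aliasing
img = ski.transform.resize(img, np.array(img.shape) // 4, anti_aliasing=True)
# The (rows, columns) shape is passed to every layer and receptive-field set below
size = img.shape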
Visualising the output of retinal layers
First, show the image that we are working with.
plot.show_composite(plot.image(img, title="Original image (greyscale)"))
Receptor layer
This layer is a simple proxy for the first layer of cells in the retina: photoreceptors.
rc_layer = ReceptorLayer(size)
Pyrception | ReceptorLayer | Initialised.
Horizontal layer
Horizontal cells compute the mean local brightness $\tilde{I}$ over the patch of the visual field covered by their receptive fields, i.e. the average of the light intensities $I$ signalled by the photoreceptors that feed into each receptive field. The receptive fields overlap (see the plot below), so a single photoreceptor can contribute to the activation of more than one horizontal cell. Each horizontal cell sends its mean brightness back as a feedback signal to all photoreceptors within its receptive field. As a result, the output signal $I^{\prime}$ of a photoreceptor is the deviation of its raw activation $I$ from the mean intensity $\tilde{I}$:
$$ I^{\prime} = I - \tilde{I} $$
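To make the feedback rule concrete, here is a minimal NumPy sketch. It is not pyrception's implementation: `horizontal_feedback` is a hypothetical helper, receptive fields are given as lists of photoreceptor indices, and a photoreceptor covered by several horizontal cells is assumed to receive the average of their mean brightness values as its $\tilde{I}$.

```python
import numpy as np


def horizontal_feedback(intensities, receptive_fields):
    """Return (cell_means, normalised_output) for a set of horizontal cells.

    intensities: 1-D array of photoreceptor intensities I.
    receptive_fields: one sequence of unique photoreceptor indices per cell.
    Assumption: a photoreceptor covered by several cells receives the
    average of their means as its feedback I~.
    """
    intensities = np.asarray(intensities, dtype=float)
    feedback_sum = np.zeros_like(intensities)
    coverage = np.zeros_like(intensities)
    cell_means = np.empty(len(receptive_fields))
    for k, idx in enumerate(receptive_fields):
        idx = np.asarray(idx, dtype=int)
        cell_means[k] = intensities[idx].mean()  # mean local brightness of cell k
        feedback_sum[idx] += cell_means[k]  # feedback to every receptor in the field
        coverage[idx] += 1
    # Average feedback over all cells covering each receptor; uncovered receptors get 0
    mean_map = np.divide(
        feedback_sum, coverage, out=np.zeros_like(intensities), where=coverage > 0
    )
    return cell_means, intensities - mean_map  # I' = I - I~
```

For intensities `[0.2, 0.4, 0.9, 0.5]` and two overlapping fields `[0, 1, 2]` and `[2, 3]`, the cell means are 0.5 and 0.7. Photoreceptor 2 lies in both fields, so under this averaging assumption it receives $\tilde{I} = 0.6$ and outputs $I^{\prime} = 0.3$.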
hz_rfs = ReceptiveFields(size, sectors=64, kbounds=(3, 3))
hz_layer = HorizontalLayer(hz_rfs)
Building kernels...: 100%|██████████| 2687/2687 [00:00<00:00, 3035.90it/s]
Pyrception | HorizontalLayer | Initialised.
Let's visualise the receptive fields of the horizontal cells.
hz_rfs.visualise()
rbegin = 30
rend = 31
hz_rfs.visualise(hz_rfs.cell_rings[rbegin:rend], title=f"Horizontal cells | Receptive fields - rings {rbegin}-{rend}")
sbegin = 35
send = 36
hz_rfs.visualise(hz_rfs.cell_sectors[sbegin:send], title=f"Horizontal cells | Receptive fields - sectors {sbegin}-{send}")
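The `cell_rings` and `cell_sectors` attributes suggest that cells are grouped by concentric ring and by angular sector. The sketch below shows one simple way to generate such a polar layout. The helper `polar_layout`, the evenly spaced radii and the equal number of cells per ring are assumptions for illustration only, not pyrception's arrangement. For instance, the 2687 horizontal cells reported above are not a multiple of 64 sectors.

```python
import numpy as np


def polar_layout(size, rings, sectors):
    """Place cell centres on concentric rings, each split into equal angular sectors.

    Hypothetical helper (not part of pyrception). Returns:
      coords: (rings * sectors, 2) array of (row, col) centres
      ring_ids: ring index of each cell
      sector_ids: sector index of each cell
    """
    h, w = size
    centre = np.array([(h - 1) / 2, (w - 1) / 2])
    max_radius = min(h, w) / 2
    # Evenly spaced ring radii, innermost first
    radii = np.linspace(max_radius / rings, max_radius, rings)
    angles = np.linspace(0.0, 2 * np.pi, sectors, endpoint=False)
    r, a = np.meshgrid(radii, angles, indexing="ij")  # both (rings, sectors)
    coords = np.stack(
        [centre[0] + r * np.sin(a), centre[1] + r * np.cos(a)], axis=-1
    ).reshape(-1, 2)
    # Row-major flattening: cell k belongs to ring k // sectors, sector k % sectors
    ring_ids = np.repeat(np.arange(rings), sectors)
    sector_ids = np.tile(np.arange(sectors), rings)
    return coords, ring_ids, sector_ids
```

Grouping indices by ring, e.g. `[np.flatnonzero(ring_ids == i) for i in range(rings)]`, gives a per-ring list of cells comparable in spirit to slicing `hz_rfs.cell_rings`.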
rc_activations = rc_layer(img)
hz_activations, hz_feedback = hz_layer(rc_layer.activations)
hz_layer.visualise_activations()
hz_layer.visualise_feedback()
Plot the activations of the horizontal layer as a scatter plot.
hz_layer.plot_activations()
Compute and visualise the normalised version of the raw input, obtained by subtracting the mean illumination map ($I^{\prime} = I - \tilde{I}$). This normalised input is what the bipolar layer receives.
hz_layer.visualise_norm_input(img)
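A dense stand-in for the mean illumination map is a local box average. The sketch below is a simplification and not what `visualise_norm_input` does internally. The hypothetical `normalise_input` averages each pixel's k x k neighbourhood, with reflected borders, to obtain $\tilde{I}$ and returns $I - \tilde{I}$.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def normalise_input(img, k=3):
    """Return (img - mean_map, mean_map) using a k x k box average as mean_map.

    Assumption: a box filter stands in for the horizontal cells' receptive fields.
    """
    if k % 2 == 0:
        raise ValueError("k must be odd so that each window is centred on its pixel")
    img = np.asarray(img, dtype=float)
    pad = k // 2
    # Reflect the borders so that every pixel has a full k x k neighbourhood
    padded = np.pad(img, pad, mode="reflect")
    # Shape (H, W, k, k): one k x k patch per output pixel
    windows = sliding_window_view(padded, (k, k))
    mean_map = windows.mean(axis=(-2, -1))
    return img - mean_map, mean_map
```

A uniform image normalises to zero everywhere. An isolated bright pixel keeps 8/9 of its intensity, while its eight neighbours become slightly negative (-1/9), a simple form of local contrast enhancement.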
Bipolar layer
bp_rfs = ReceptiveFields(size=size, sectors=128, substrate=rc_layer.substrate)
bp_layer = BipolarLayer(bp_rfs)
Building kernels...: 100%|██████████| 8711/8711 [00:02<00:00, 3422.72it/s]
Pyrception | BipolarLayer | Initialised.
bp_rfs.visualise(outline_colour=None, kernel_colour=None)
bl_on, bl_off = bp_layer(rc_activations, hz_feedback)
bp_layer.visualise_activations()
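`BipolarLayer` returns two outputs, `bl_on` and `bl_off`. Their names indicate ON and OFF bipolar cells. This sketch assumes a common textbook simplification that does not come from pyrception. It half-wave rectifies the normalised signal $I^{\prime}$: ON cells carry the positive part (brighter than the local mean) and OFF cells carry the magnitude of the negative part. `on_off_split` is a hypothetical helper.

```python
import numpy as np


def on_off_split(signal):
    """Split a contrast signal into half-wave rectified ON and OFF channels.

    Assumption: ON responds to positive deviations, OFF to negative ones.
    """
    signal = np.asarray(signal, dtype=float)
    on = np.maximum(signal, 0.0)  # brighter than the local mean
    off = np.maximum(-signal, 0.0)  # darker than the local mean, as a positive value
    return on, off
```

Because each value lands in exactly one channel, `on - off` recovers the original signal.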
Amacrine layer
am_rfs = ReceptiveFields(size, substrate=bp_rfs.cell_coordinates, kscale=2.0)
am_layer = AmacrineLayer(am_rfs)
Building kernels...: 100%|██████████| 2687/2687 [00:01<00:00, 2354.86it/s]
Pyrception | AmacrineLayer | Initialised.
am_rfs.visualise(kernel_colour="#ff00ffff", cells=2400, outline_colour=None)
am_on = am_layer(bl_on)
am_layer.visualise_activations()
Ganglion layer
center = ReceptiveFields(size=size, substrate=bp_rfs.cell_coordinates)
surround = ReceptiveFields(size=size, substrate=am_rfs.cell_coordinates, kscale=2.0)
gl_layer = GanglionLayer(center, surround)
Building kernels...: 100%|██████████| 2687/2687 [00:00<00:00, 3172.35it/s]
Building kernels...: 100%|██████████| 2687/2687 [00:01<00:00, 2324.81it/s]
Pyrception | GanglionLayer | Initialised.
gl_out = gl_layer.forward(bl_on, am_on)
gl_layer.visualise_spikes()
logger.info(f"Active ganglion cells: {gl_out.sum()} / {len(gl_out)}")
Pyrception | Active ganglion cells: 651 / 2687
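The ganglion layer combines a centre drawn from the bipolar cells with a surround drawn from the amacrine cells, whose kernels are scaled by `kscale=2.0`. The following is a minimal sketch of such a centre-surround unit. It assumes Gaussian pooling and a fixed spike threshold, both of which are assumptions. `gaussian_pool` and `ganglion_spikes` are hypothetical helpers, not part of pyrception.

```python
import numpy as np


def gaussian_pool(values, coords, centres, sigma):
    """Gaussian-weighted mean of `values` (located at `coords`) around each centre."""
    values = np.asarray(values, dtype=float)
    coords = np.asarray(coords, dtype=float)
    centres = np.asarray(centres, dtype=float)
    # Squared distance from every centre to every input: (n_centres, n_inputs)
    d2 = ((centres[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    weights = np.exp(-d2 / (2.0 * sigma**2))
    return weights @ values / weights.sum(axis=1)


def ganglion_spikes(centre_vals, centre_coords, surround_vals, surround_coords,
                    cell_coords, sigma=1.0, kscale=2.0, threshold=0.0):
    """Hypothetical centre-surround ganglion model.

    Centre: pool bipolar-like inputs with width `sigma`.
    Surround: pool amacrine-like inputs with width `kscale * sigma`.
    A cell spikes (True) when centre minus surround exceeds `threshold`.
    """
    centre = gaussian_pool(centre_vals, centre_coords, cell_coords, sigma)
    surround = gaussian_pool(surround_vals, surround_coords, cell_coords, kscale * sigma)
    return (centre - surround) > threshold
```

With a single bright input at position 5 on a row of 11 cells, only cells 4, 5 and 6 spike. A uniform input produces no spikes, because centre and surround cancel. `spikes.mean()` gives the fraction of active cells, the analogue of the 651 / 2687 (about 24%) reported above.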