Building foveated deep vision models based on kNN-convolution
Next, we will dive into how to perform perceptual processing of our foveated sensor outputs, making use of the sensor manifold.
For this, we use k-nearest-neighbor (kNN) receptive fields. In 2-D, receptive fields are specified as \((h,w)\) rectangular grids; on our 3-D manifold, they are specified as kNNs.
The details are described in the paper. Here, we will go through the relevant code modules to see how to build networks based on kNN-convolution on the foveated sensor manifold.
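To make the idea concrete, here is a minimal sketch (not the fovi implementation) of how kNN receptive fields can be computed from 3-D coordinates: for each output point, the receptive field is simply the indices of its k nearest input points.

```python
import torch

def knn_receptive_fields(out_coords, in_coords, k):
    """For each output coordinate, return the indices of its k nearest
    input coordinates; these indices define the receptive field."""
    # pairwise Euclidean distances: [n_out, n_in]
    dists = torch.cdist(out_coords, in_coords)
    # indices of the k closest input points per output point
    _, idx = dists.topk(k, dim=1, largest=False)
    return idx  # [n_out, k]

# toy example: 100 input points, 10 output points on a 3-D manifold
in_coords = torch.rand(100, 3)
out_coords = torch.rand(10, 3)
rf = knn_receptive_fields(out_coords, in_coords, k=9)
print(rf.shape)  # torch.Size([10, 9])
```

Unlike a rectangular kernel, this construction makes no assumption about how the points are laid out, which is what lets the same machinery operate on the non-uniform sensor manifold.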
We will now start looking at fovi.arch, where all of the architectural features relevant to foveated perceptual processing live.
The building block layers are stored in fovi.arch.knn. Let’s use them to build a super simple 1-layer convolutional network: a convolution layer followed by a pooling layer, a ReLU nonlinearity, and batch normalization.
[1]:
%load_ext autoreload
%autoreload 2
from fovi.arch.knn import KNNPoolingLayer, KNNConvLayer, get_in_out_coords
from fovi.arch.norm import KNNBatchNorm
import torch.nn as nn
fov = 16
cmf_a = 0.5
device = 'cpu'
cartesian_res = 64
conv_kernel_cartesian = [7, 7]
pool_kernel_cartesian = [3, 3]
conv_stride = 1
pool_stride = 2
channels = 128
# determine neighborhood sizes based on target cartesian kernels
k_conv = conv_kernel_cartesian[0]*conv_kernel_cartesian[1]
k_pool = pool_kernel_cartesian[0]*pool_kernel_cartesian[1]
# set up coordinates based on input resolution and strides
in_cart_res = cartesian_res
sensor_coords, conv_coords, out_cart_res = get_in_out_coords(in_cart_res, fov, cmf_a, conv_stride, in_cart_res=in_cart_res, device=device)
# the previous layer is the input to the next layer
in_cart_res = out_cart_res
_, pool_coords, _ = get_in_out_coords(in_cart_res, fov, cmf_a, pool_stride, in_cart_res=in_cart_res, in_coords=conv_coords, device=device)
conv_layer = KNNConvLayer(3, channels, k_conv, sensor_coords, conv_coords,
                          ref_frame_side_length=2*conv_kernel_cartesian[0],
                          device=device,
                          )
pool_layer = KNNPoolingLayer(k_pool, conv_coords, pool_coords, mode='max', device=device)
full_layer = nn.Sequential(
    conv_layer,
    pool_layer,
    nn.ReLU(),
    KNNBatchNorm(len(pool_coords), channels, device=device),
).to(device)
/n/home12/nblauch/.conda/envs/new_workshop/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Note that auto_match_cart_resources=1 attempts to match the Cartesian-equivalent resources as closely as possible, but cannot be perfect (see the printout). The default behavior is to ensure that we select no more than the Cartesian-equivalent resources.
OK!
Let’s create some fake data and pass it through our simple fovi layer.
[2]:
import torch
# Our foveated sensor outputs and KNNLayer inputs/outputs are formatted as [batch, num_channels, num_coords]
x_sensor = torch.rand([64, 3, len(sensor_coords)]).to(device)
layer_output = full_layer(x_sensor)
[3]:
layer_output.shape
[3]:
torch.Size([64, 128, 964])
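For comparison with standard 2-D convnets: the [batch, num_channels, num_coords] format corresponds to a Cartesian feature map [batch, channels, H, W] with the spatial grid flattened into a single coordinate axis. A minimal sketch:

```python
import torch

# a standard 2-D feature map: [batch, channels, H, W]
x_2d = torch.rand(64, 3, 8, 8)
# flatten the spatial grid into one coordinate dimension: [batch, channels, H*W]
x_coords = x_2d.flatten(start_dim=2)
print(x_coords.shape)  # torch.Size([64, 3, 64])
```

On the sensor manifold there is no underlying rectangular grid, so the coordinate axis is the native representation rather than a flattened view.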
In practice, the KNNAlexNetBlock is set up to do exactly what we just did: combine conv, pooling, nonlinearity, and normalization.
[4]:
from fovi.arch.knnalexnet import KNNAlexNetBlock
block = KNNAlexNetBlock(
    3, channels, k_conv, fov, cmf_a, cartesian_res, conv_stride,
    cart_res=cartesian_res, pool=True, pool_k=k_pool, pool_stride=pool_stride,
    norm_type='batch', auto_match_cart_resources=1, device=device, ref_frame_mult=2,
)
[5]:
block_output = block(x_sensor)
block_output.shape
[5]:
torch.Size([64, 128, 964])
Building an AlexNet-like KNN model
Now that we’ve seen the layers and blocks, we are ready to build a complete KNNAlexNet model. For this, we will just use our wrapper function, and refer you to the code for further detail. The KNNAlexNet class builds AlexNet-like models, but is flexible with respect to the number of layers, kernel sizes, channel dimensions, etc.
[6]:
from fovi.arch.knnalexnet import KNNAlexNet
model = KNNAlexNet(
    cartesian_res,
    3,
    [96, 256, 384, 256, 256],  # channels per layer
    [4, 1, 1, 1, 1],  # conv stride per layer
    [1, 4],  # pool after these layers
    [11**2, 5**2, 3**2, 3**3, 3**2],  # k per layer
    n_classes=1000,
    out_res=None,  # no output pooling
    auto_match_cart_resources=1,
    norm_type='batch',
    ref_frame_mult=2,
    fov=fov,
    cmf_a=cmf_a,
    device=device,
)
no output pooling layer
[7]:
model
[7]:
KNNAlexNet(
(layers): ModuleList(
(0): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=3
out_channels=96
k=121
n_ref=484
in_coords=SamplingCoords(length=4085, fov=16, cmf_a=0.5, resolution=53, style=isotropic)
out_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(1): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=96
out_channels=256
k=25
n_ref=100
in_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)
out_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
(pool): KNNPoolingLayer(
mode=max
k=9
in_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
)
(2): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=256
out_channels=384
k=9
n_ref=36
in_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(3): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=384
out_channels=256
k=27
n_ref=121
in_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(4): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=256
out_channels=256
k=9
n_ref=36
in_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
(pool): KNNPoolingLayer(
mode=max
k=9
in_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=16, fov=16, cmf_a=0.5, resolution=4, style=isotropic)
sample_cortex=True
)
)
)
(classifier): Linear(in_features=4096, out_features=1000, bias=True)
)
Building a complete FoviNet
You may have noticed that we have only worked with fake data thus far. Our KNN architectures are designed to work with outputs formatted on the sensor manifold. To get these from images, we need to use a RetinalTransform object, or more simply, a SaccadePolicy, which will also determine our fixations for us. If you need a refresher on these, go back to step1_sampling.ipynb.
The last piece of the puzzle is the FoviNet class, which combines a fixation policy with a processing network. Since there are a lot of hyperparameters, we specify them all in a neat hierarchical config. This is typically specified as a .yaml file, and we use hydra and omegaconf to handle these in our training scripts. Because the config uses inheritance, we initialize and compose it with hydra.
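As a schematic illustration only (these field names are hypothetical, not the actual fovi config schema), a hierarchical .yaml config of this kind might look like:

```yaml
# hypothetical sketch of a hierarchical config; field names are illustrative
sensor:
  fov: 16
  cmf_a: 0.5
  resolution: 64
network:
  channels: [96, 256, 384, 256, 256]
  norm_type: batch
policy:
  n_fixations: 4
```

Hydra composes such files with their defaults and inherited groups into a single OmegaConf DictConfig, which is what FoviNet consumes below.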
[8]:
from fovi.fovinet import FoviNet
from hydra import compose, initialize
# Use hydra/omegaconf to process the hierarchical config, including all defaults
with initialize(version_base=None, config_path="../config"):
    config = compose(config_name="fovi_alexnet.yaml")
print(type(config))  # omegaconf DictConfig
fovinet = FoviNet(config, device=device)
<class 'omegaconf.dictconfig.DictConfig'>
adjusting FOV for fixation: 16.0 (full: 16.0)
/n/home12/nblauch/git/foveation-private/fovi/sensing/coords.py:349: RuntimeWarning: divide by zero encountered in scalar divide
w_delta = (w_max - w_min)/(res-1)
Note: horizontal flip always done in the loader, to avoid differences across fixations
Number of coords per layer: [4085, 964, 230, 230, 60, 60, 60, 60, 16, 1]
[9]:
fovinet
[9]:
FoviNet(
(network): BackboneProjectorWrapper(
(backbone): KNNAlexNet(
(layers): ModuleList(
(0): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=3
out_channels=96
k=121
n_ref=484
in_coords=SamplingCoords(length=4085, fov=16.0, cmf_a=0.5, resolution=53, style=isotropic)
out_coords=SamplingCoords(length=964, fov=16.0, cmf_a=0.5, resolution=26, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
(pool): KNNPoolingLayer(
mode=max
k=9
in_coords=SamplingCoords(length=964, fov=16.0, cmf_a=0.5, resolution=26, style=isotropic)
out_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)
sample_cortex=True
)
)
(1): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=96
out_channels=256
k=25
n_ref=100
in_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)
out_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
(pool): KNNPoolingLayer(
mode=max
k=9
in_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
)
(2): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=256
out_channels=384
k=9
n_ref=36
in_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(3): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=384
out_channels=384
k=9
n_ref=36
in_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(4): KNNAlexNetBlock(
(conv): KNNConvLayer(
in_channels=384
out_channels=256
k=9
n_ref=36
in_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
sample_cortex=True
)
(norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
(pool): KNNPoolingLayer(
mode=max
k=9
in_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)
out_coords=SamplingCoords(length=16, fov=16.0, cmf_a=0.5, resolution=4, style=isotropic)
sample_cortex=True
)
)
(5): KNNPoolingLayer(
mode=avg
k=16
in_coords=SamplingCoords(length=16, fov=16.0, cmf_a=0.5, resolution=4, style=isotropic)
out_coords=SamplingCoords(length=1, fov=16.0, cmf_a=0.5, resolution=1, style=isotropic)
sample_cortex=True
)
)
)
(projector): MLPWrapper(
(layers): Sequential(
(fc_block_6): LayerBlock(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=256, out_features=1024, bias=True)
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
(fc_block_7): LayerBlock(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=1024, out_features=1024, bias=False)
)
)
)
)
(retinal_transform): RetinalTransform(
(foveal_color): GaussianColorDecay(sigma=None)
(sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
)
(ssl_fixator): NoSaccadePolicy(
retinal_transform=RetinalTransform(
(foveal_color): GaussianColorDecay(sigma=None)
(sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
),
n_fixations=1
)
(sup_fixator): MultiRandomSaccadePolicy(
retinal_transform=RetinalTransform(
(foveal_color): GaussianColorDecay(sigma=None)
(sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)
),
n_fixations=4,
nonrandom_first=False,
nonrandom_val=False,
crop_area_range=[1.0, 1.0],
add_aspect_variation=None,
val_crop_size=1.0,
norm_dist_from_center=0.25
)
(head): FoviNetProbe(
(fix_projector): LinearProbe(
(dropout): Dropout(p=0.5, inplace=False)
(probe): Linear(in_features=1024, out_features=1000, bias=True)
)
)
)
We can now process image data with our FoviNet model.
[10]:
from fovi.demo import get_image_as_batch
# data = torch.rand([128, 3, 256, 256]).to(device)
data = get_image_as_batch(device=device)
print(data.shape)
category_logits, layer_outputs, x_fixs = fovinet(data)
torch.Size([1, 3, 256, 256])
[11]:
# global avg pool of final conv layer
print(layer_outputs[0].shape) # (batch, num_fixations, conv_dim)
# final MLP layer:
print(layer_outputs[-1].shape) # (batch, num_fixations, fc_dim)
# category logits averaged across fixations
print(category_logits.shape) # (batch, num_classes)
torch.Size([1, 4, 256])
torch.Size([1, 4, 1024])
torch.Size([1, 1000])
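The relationship between the per-fixation outputs and the final logits can be sketched as follows (a schematic under our assumptions, not the actual FoviNet code): the per-fixation logits are averaged over the fixation dimension to produce a single prediction per image.

```python
import torch

batch, n_fix, n_classes = 1, 4, 1000
# per-fixation category logits: [batch, num_fixations, num_classes]
per_fix_logits = torch.rand(batch, n_fix, n_classes)
# average over the fixation dimension: [batch, num_classes]
avg_logits = per_fix_logits.mean(dim=1)
print(avg_logits.shape)  # torch.Size([1, 1000])
```

Averaging over fixations lets the model pool evidence gathered from multiple glances at the same image before classifying.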
Loading a pre-trained model
[12]:
from fovi import load_config
base_fn = 'fovi-alexnet_a-1_res-64_in1k'
config, state_dict, model_key = load_config(base_fn, load=True, folder='../models', device='cpu')
model = FoviNet(config, device='cpu')
model.load_state_dict(state_dict[model_key])
adjusting FOV for fixation: 16.0 (full: 16.0)
Note: horizontal flip always done in the loader, to avoid differences across fixations
Number of coords per layer: [4085, 964, 230, 230, 60, 60, 60, 60, 16, 1]
[12]:
<All keys matched successfully>