{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Building foveated deep vision models based on kNN-convolution\n", "\n", "Next, we will dive into how to do perceptual processing of our foveated sensor outputs, making use of the sensor manifold.\n", "\n", "For this, we use k-nearest-neighbor (kNN) receptive fields. In 2-D, receptive fields are specified as $(h,w)$ rectangular grids; on our 3-D manifold, they are specified as kNNs. \n", "\n", "The details are described in the paper. Here, we will go through the relevant code modules to see how to build up networks based on kNN-convolution on the foveated sensor manifold.\n", "\n", "We will now start looking at `fovi.arch`, where all of the architectural features relevant to foveated perceptual processing live. \n", "\n", "The building block layers are stored in `fovi.arch.knn`. Let's use them to build a simple single-block convolutional network: a convolution layer followed by pooling, a ReLU nonlinearity, and normalization." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/n/home12/nblauch/.conda/envs/new_workshop/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n", " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "from fovi.arch.knn import KNNPoolingLayer, KNNConvLayer, get_in_out_coords\n", "from fovi.arch.norm import KNNBatchNorm\n", "import torch.nn as nn\n", "\n", "fov = 16\n", "cmf_a = 0.5\n", "device = 'cpu'\n", "\n", "cartesian_res = 64\n", "conv_kernel_cartesian = [7, 7]\n", "pool_kernel_cartesian = [3, 3]\n", "conv_stride = 1\n", "pool_stride = 2\n", "channels = 128\n", "\n", "# determine neighborhood sizes based on target cartesian kernels\n", "k_conv = conv_kernel_cartesian[0]*conv_kernel_cartesian[1]\n", "k_pool = pool_kernel_cartesian[0]*pool_kernel_cartesian[1]\n", "\n", "# set up coordinates based on input resolution and strides\n", "in_cart_res = cartesian_res\n", "sensor_coords, conv_coords, out_cart_res = get_in_out_coords(in_cart_res, fov, cmf_a, conv_stride, in_cart_res=in_cart_res, device=device)\n", "# the previous layer is the input to the next layer\n", "in_cart_res = out_cart_res\n", "_, pool_coords, _ = get_in_out_coords(in_cart_res, fov, cmf_a, pool_stride, in_cart_res=in_cart_res, in_coords=conv_coords, device=device)\n", "\n", "conv_layer = KNNConvLayer(3, channels, k_conv, sensor_coords, conv_coords, \n", " ref_frame_side_length=2*conv_kernel_cartesian[0], \n", " device=device,\n", " )\n", "pool_layer = KNNPoolingLayer(k_pool, conv_coords, pool_coords, mode='max', device=device)\n", "\n", "full_layer = nn.Sequential(\n", " conv_layer,\n", " pool_layer,\n", " nn.ReLU(),\n", " KNNBatchNorm(len(pool_coords), channels, device=device),\n", ").to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `auto_match_cart_resources=1` attempts to match resources as closely as possible, but cannot be perfect (see the printout). 
The default behavior is to ensure that we select at most the Cartesian-equivalent number of resources.\n", "\n", "OK!\n", "\n", "Let's create some fake data and pass it through our simple fovi layer." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Our foveated sensor outputs and KNNLayer inputs/outputs are formatted as [batch, num_channels, num_coords]\n", "x_sensor = torch.rand([64, 3, len(sensor_coords)]).to(device)\n", "\n", "layer_output = full_layer(x_sensor)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 128, 964])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "layer_output.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In practice, the `KNNAlexNetBlock` is set up to do exactly what we just did: combine conv, pooling, nonlinearity, and normalization." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from fovi.arch.knnalexnet import KNNAlexNetBlock\n", "\n", "block = KNNAlexNetBlock(3, channels, k_conv, fov, cmf_a, cartesian_res, conv_stride, cart_res=cartesian_res, pool=True, pool_k=k_pool, pool_stride=pool_stride, norm_type='batch', auto_match_cart_resources=1, device=device, ref_frame_mult=2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 128, 964])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "block_output = block(x_sensor)\n", "\n", "block_output.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Building an AlexNet-like KNN model\n", "\n", "Now that we've seen the layers and blocks, we are ready to build a complete `KNNAlexNet` model. For this, we will just use our wrapper function and refer you to the code for further detail. 
The `KNNAlexNet` class builds AlexNet-like models, but is more flexible, supporting different numbers of layers, kernel sizes, channel dimensions, etc. " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "no output pooling layer\n" ] } ], "source": [ "from fovi.arch.knnalexnet import KNNAlexNet\n", "\n", "model = KNNAlexNet(\n", " cartesian_res, \n", " 3, \n", " [96, 256, 384, 256, 256], # channels per layer\n", " [4,1,1,1,1], # conv stride per layer\n", " [1,4], # pool after\n", " [11**2, 5**2, 3**2, 3**3, 3**2], # k per layer\n", " n_classes=1000,\n", " out_res=None, # no output pooling\n", " auto_match_cart_resources=1,\n", " norm_type='batch',\n", " ref_frame_mult=2,\n", " fov=fov,\n", " cmf_a=cmf_a,\n", " device=device,\n", " )" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KNNAlexNet(\n", " (layers): ModuleList(\n", " (0): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=3\n", " \tout_channels=96\n", " \tk=121\n", " \tn_ref=484\n", " \tin_coords=SamplingCoords(length=4085, fov=16, cmf_a=0.5, resolution=53, style=isotropic)\n", " \tout_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " )\n", " (1): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=96\n", " \tout_channels=256\n", " \tk=25\n", " \tn_ref=100\n", " \tin_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)\n", " \tout_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " (pool): KNNPoolingLayer(\n", " 
\tmode=max\n", " \tk=9\n", " \tin_coords=SamplingCoords(length=230, fov=16, cmf_a=0.5, resolution=13, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " )\n", " (2): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=256\n", " \tout_channels=384\n", " \tk=9\n", " \tn_ref=36\n", " \tin_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " )\n", " (3): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=384\n", " \tout_channels=256\n", " \tk=27\n", " \tn_ref=121\n", " \tin_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " )\n", " (4): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=256\n", " \tout_channels=256\n", " \tk=9\n", " \tn_ref=36\n", " \tin_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " (pool): KNNPoolingLayer(\n", " \tmode=max\n", " \tk=9\n", " \tin_coords=SamplingCoords(length=60, fov=16, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=16, fov=16, cmf_a=0.5, resolution=4, style=isotropic)\n", " \tsample_cortex=True\n", " 
)\n", " )\n", " )\n", " (classifier): Linear(in_features=4096, out_features=1000, bias=True)\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Building a complete FoviNet\n", "\n", "You may have noticed that we have only worked with fake data thus far. Our KNN architectures are designed to work with outputs formatted on the sensor manifold. To get this from images, we need to use a `RetinalTransform` object, or more simply, a `SaccadePolicy`, which will also determine our fixations for us. If you need a refresher on these, go back to `step1_sampling.ipynb`. \n", "\n", "The last piece of the puzzle is the `FoviNet` class, which combines a fixation policy with a processing network. Since there are a lot of hyperparameters, we specify them all here in a neat hierarchical `config`. This is typically specified as a `.yaml` file, and we use `hydra` and `omegaconf` to handle these in our training scripts. Because the config uses inheritance, we initialize and compose it with `hydra`. 
" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "adjusting FOV for fixation: 16.0 (full: 16.0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/n/home12/nblauch/git/foveation-private/fovi/sensing/coords.py:349: RuntimeWarning: divide by zero encountered in scalar divide\n", " w_delta = (w_max - w_min)/(res-1)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Note: horizontal flip always done in the loader, to avoid differences across fixations\n", "Number of coords per layer: [4085, 964, 230, 230, 60, 60, 60, 60, 16, 1]\n" ] } ], "source": [ "from fovi.fovinet import FoviNet\n", "from hydra import compose, initialize\n", "\n", "# Use hydra/omega to process the hierarchical config, including all defaults\n", "with initialize(version_base=None, config_path=\"../config\"):\n", " config = compose(config_name=\"fovi_alexnet.yaml\")\n", "print(type(config)) # OmegaConf DictConfig\n", "\n", "fovinet = FoviNet(config, device=device)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FoviNet(\n", " (network): BackboneProjectorWrapper(\n", " (backbone): KNNAlexNet(\n", " (layers): ModuleList(\n", " (0): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=3\n", " \tout_channels=96\n", " \tk=121\n", " \tn_ref=484\n", " \tin_coords=SamplingCoords(length=4085, fov=16.0, cmf_a=0.5, resolution=53, style=isotropic)\n", " \tout_coords=SamplingCoords(length=964, fov=16.0, cmf_a=0.5, resolution=26, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " (pool): KNNPoolingLayer(\n", " \tmode=max\n", " \tk=9\n", " \tin_coords=SamplingCoords(length=964, fov=16.0, cmf_a=0.5, resolution=26, style=isotropic)\n", " \tout_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, 
resolution=13, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " )\n", " (1): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=96\n", " \tout_channels=256\n", " \tk=25\n", " \tn_ref=100\n", " \tin_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)\n", " \tout_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " (pool): KNNPoolingLayer(\n", " \tmode=max\n", " \tk=9\n", " \tin_coords=SamplingCoords(length=230, fov=16.0, cmf_a=0.5, resolution=13, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " )\n", " (2): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=256\n", " \tout_channels=384\n", " \tk=9\n", " \tn_ref=36\n", " \tin_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " )\n", " (3): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=384\n", " \tout_channels=384\n", " \tk=9\n", " \tn_ref=36\n", " \tin_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " )\n", " (4): KNNAlexNetBlock(\n", " (conv): KNNConvLayer(\n", " \tin_channels=384\n", " \tout_channels=256\n", " \tk=9\n", " \tn_ref=36\n", " 
\tin_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " (norm): KNNBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " (pool): KNNPoolingLayer(\n", " \tmode=max\n", " \tk=9\n", " \tin_coords=SamplingCoords(length=60, fov=16.0, cmf_a=0.5, resolution=7, style=isotropic)\n", " \tout_coords=SamplingCoords(length=16, fov=16.0, cmf_a=0.5, resolution=4, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " )\n", " (5): KNNPoolingLayer(\n", " \tmode=avg\n", " \tk=16\n", " \tin_coords=SamplingCoords(length=16, fov=16.0, cmf_a=0.5, resolution=4, style=isotropic)\n", " \tout_coords=SamplingCoords(length=1, fov=16.0, cmf_a=0.5, resolution=1, style=isotropic)\n", " \tsample_cortex=True\n", " )\n", " )\n", " )\n", " (projector): MLPWrapper(\n", " (layers): Sequential(\n", " (fc_block_6): LayerBlock(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Linear(in_features=256, out_features=1024, bias=True)\n", " (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): ReLU(inplace=True)\n", " )\n", " (fc_block_7): LayerBlock(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Linear(in_features=1024, out_features=1024, bias=False)\n", " )\n", " )\n", " )\n", " )\n", " (retinal_transform): RetinalTransform(\n", " (foveal_color): GaussianColorDecay(sigma=None)\n", " (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)\n", " )\n", " (ssl_fixator): NoSaccadePolicy(\n", " retinal_transform=RetinalTransform(\n", " (foveal_color): GaussianColorDecay(sigma=None)\n", " (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)\n", " ),\n", " n_fixations=1\n", " )\n", " (sup_fixator): MultiRandomSaccadePolicy(\n", " 
retinal_transform=RetinalTransform(\n", " (foveal_color): GaussianColorDecay(sigma=None)\n", " (sampler): GridSampler(fov=16.0, cmf_a=0.5, style=isotropic, resolution=53, mode=nearest, n=4085)\n", " ),\n", " n_fixations=4,\n", " nonrandom_first=False,\n", " nonrandom_val=False,\n", " crop_area_range=[1.0, 1.0],\n", " add_aspect_variation=None,\n", " val_crop_size=1.0,\n", " norm_dist_from_center=0.25\n", " )\n", " (head): FoviNetProbe(\n", " (fix_projector): LinearProbe(\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " (probe): Linear(in_features=1024, out_features=1000, bias=True)\n", " )\n", " )\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fovinet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now process image data with our `FoviNet` model." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 3, 256, 256])\n" ] } ], "source": [ "from fovi.demo import get_image_as_batch\n", "\n", "# data = torch.rand([128, 3, 256, 256]).to(device)\n", "\n", "data = get_image_as_batch(device=device)\n", "\n", "print(data.shape)\n", "\n", "category_logits, layer_outputs, x_fixs = fovinet(data)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 4, 256])\n", "torch.Size([1, 4, 1024])\n", "torch.Size([1, 1000])\n" ] } ], "source": [ "# global avg pool of final conv layer\n", "print(layer_outputs[0].shape) # (batch, num_fixations, conv_dim)\n", "# final MLP layer:\n", "print(layer_outputs[-1].shape) # (batch, num_fixations, fc_dim)\n", "# category logits averaged across fixations\n", "print(category_logits.shape) # (batch, num_classes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading a pre-trained model" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "adjusting FOV for fixation: 16.0 (full: 16.0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Note: horizontal flip always done in the loader, to avoid differences across fixations\n", "Number of coords per layer: [4085, 964, 230, 230, 60, 60, 60, 60, 16, 1]\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fovi import load_config\n", "\n", "base_fn = 'fovi-alexnet_a-1_res-64_in1k'\n", "config, state_dict, model_key = load_config(base_fn, load=True, folder='../models', device='cpu')\n", "model = FoviNet(config, device='cpu')\n", "model.load_state_dict(state_dict[model_key])" ] } ], "metadata": { "kernelspec": { "display_name": "new_workshop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 2 }