Source code for fovi

from omegaconf import OmegaConf, open_dict
import torch.nn as nn
from typing import Type
from numba.core.config import NUMBA_NUM_THREADS

from .utils import HiddenPrints
from .paths import SAVE_DIR, SLOW_DIR
from .trainer import Trainer, find_config, load_config
from .fovinet import FoviNet

def get_trainer_from_base_fn(base_fn, load=True, load_strict=True, quiet=False, allow_distributed=False, gpu=0,
                             model_dirs=['../models', SAVE_DIR + '/logs', SLOW_DIR + '/logs'],
                             **kwargs,
                             ):
    """
    Get a Trainer instance based on a base filename and optional parameters.

    The configuration (and, if requested, the saved weights) is looked up with
    ``find_config``, adjusted with any overrides given in ``kwargs``, and used to
    build a ``Trainer``. If ``load`` is True, the saved model weights (and probe
    weights, when present) are then loaded into the trainer.

    Args:
        base_fn (str): The base filename to look for; forwarded to ``find_config``.
        load (bool, optional): Whether to load the model weights. Defaults to True.
        load_strict (bool, optional): Forwarded as ``strict`` to ``load_state_dict``
            for both the model and the probes. Defaults to True.
        quiet (bool, optional): Whether to suppress print statements (forwarded to
            ``HiddenPrints``). Defaults to False.
        allow_distributed (bool, optional): Whether to keep a distributed training
            configuration. When False and the stored config is distributed, the
            config is rewritten for a single process (world size, GPU count and
            node count set to 1, data workers divided by the previous world size)
            and any leading ``module.`` prefix is stripped from the saved
            parameter names. Defaults to False.
        gpu (int, optional): Passed as the first positional argument to ``Trainer``
            (presumably the GPU index -- confirm against ``Trainer``). Defaults to 0.
        model_dirs (list, optional): Forwarded to ``find_config`` as the locations
            to look in. Defaults to ``'../models'``, ``SAVE_DIR/logs`` and
            ``SLOW_DIR/logs``.
        **kwargs: Dotted config keys (e.g. ``'data.num_classes'``) mapped to the
            values that should override or extend the loaded configuration.

    Returns:
        Trainer: An instance of Trainer with the specified configuration and
        optionally loaded weights.

    Note:
        If ``'data.num_classes'`` is overridden with a value different from the
        stored one, every parameter whose name contains ``'head.'`` is left out
        of the loaded weights and the probes are not loaded. In that case strict
        loading will presumably complain about the missing head parameters, so
        ``load_strict=False`` is likely required -- confirm against the model.
        How ``find_config`` chooses between final and non-final weights is not
        visible here.
    """
    with HiddenPrints(quiet):
        cfg, state_dict, model_key = find_config(base_fn, load=load, model_dirs=model_dirs)
        cfg.logging.use_wandb = 0

        # we changed from specifying foveal diameter "fovea" ($2a$) to specifying $a$ directly
        if not hasattr(cfg.saccades, 'cmf_a'):
            cfg.saccades.cmf_a = cfg.saccades.fovea/2

        load_head = True
        for k, v in kwargs.items():
            if k == 'data.num_classes' and cfg.data.num_classes != v:
                # changing # of classes, this will mess up state dict loading, so we don't load the head
                load_head = False
            with open_dict(cfg):
                OmegaConf.update(cfg, k, v)

        if not allow_distributed and cfg.training.distributed:
            # collapse the distributed configuration onto a single process
            prev_num_workers = cfg.data.num_workers
            prev_world_size = cfg.dist.world_size
            if prev_num_workers is None:
                # never request a negative worker count on machines with few threads
                cfg.data.num_workers = max(NUMBA_NUM_THREADS - 2, 0)
            else:
                cfg.data.num_workers = prev_num_workers // prev_world_size
            cfg.dist.world_size = 1
            cfg.dist.ngpus = 1
            cfg.dist.nodes = 1
            cfg.training.distributed = 0

        if 'logging.base_fn' in kwargs and 'logging.folder' not in kwargs:
            # update logging folder if updating base_fn
            cfg.logging.folder = f"{SAVE_DIR}/logs/{cfg.logging.base_fn}"

        trainer = Trainer(gpu, cfg, load_checkpoint=False)

        if load:
            # 'module.' prefixes are only stripped when the trainer is not distributed
            strip_prefix = not allow_distributed
            prefix = 'module.'
            # The head must be dropped whenever the class count changed, regardless of
            # allow_distributed; previously it was only dropped on the non-distributed path.
            if strip_prefix or not load_head:
                keys = [model_key]
                if 'probes' in state_dict:
                    keys.append('probes')
                for this_key in keys:
                    new_state_dict = {}
                    for k, v in state_dict[this_key].items():
                        if not load_head and 'head.' in k:
                            continue
                        if strip_prefix and k.startswith(prefix):
                            k = k[len(prefix):]
                        new_state_dict[k] = v
                    state_dict[this_key] = new_state_dict

            trainer.model.load_state_dict(state_dict[model_key], strict=load_strict)
            if load_head and 'probes' in state_dict:
                trainer.probes.load_state_dict(state_dict['probes'], strict=load_strict)

    return trainer
def get_model_from_base_fn(base_fn, load=True, load_strict=True, quiet=False, device='cuda',
                           model_dirs=['../models'],
                           fovinet_cls: Type[nn.Module] = FoviNet,
                           **kwargs,
                           ):
    """
    Build a model from a saved configuration, optionally restoring its weights.

    ``find_config`` supplies the configuration (and the saved state when
    ``load`` is True). Overrides from ``kwargs`` are written into the
    configuration, ``fovinet_cls`` is instantiated from it, and the saved
    weights are loaded into the new model if requested.

    Args:
        base_fn (str): The base filename to look for; forwarded to ``find_config``.
        load (bool, optional): Whether to load the model weights. Defaults to True.
        load_strict (bool, optional): Forwarded as ``strict`` to the model's
            ``load_state_dict``. Defaults to True.
        quiet (bool, optional): Whether to suppress print statements (forwarded to
            ``HiddenPrints``). Defaults to False.
        device (str, optional): Passed to ``fovinet_cls`` as ``device``.
            Defaults to 'cuda'.
        model_dirs (list, optional): Forwarded to ``find_config`` as the locations
            to look in. Defaults to ``['../models']``.
        fovinet_cls (Type[nn.Module], optional): Class called as
            ``fovinet_cls(cfg, device=device)`` to create the model.
            Defaults to FoviNet.
        **kwargs: Dotted config keys (e.g. ``'data.num_classes'``) mapped to the
            values that should override or extend the loaded configuration.

    Returns:
        The ``fovinet_cls`` instance built from the configuration, with the
        saved weights loaded when ``load`` is True.

    Note:
        If ``'data.num_classes'`` is overridden with a value different from the
        stored one, parameters whose name contains ``'head.'`` are not loaded.
        A leading ``module.`` (left over from distributed training) is removed
        from saved parameter names before loading.
    """
    # NOTE(review): get_trainer_from_base_fn back-fills cfg.saccades.cmf_a from
    # cfg.saccades.fovea for older configs; nothing equivalent happens here --
    # confirm whether fovinet_cls needs it.
    with HiddenPrints(quiet):
        cfg, state_dict, model_key = find_config(base_fn, load=load, model_dirs=model_dirs)

        if hasattr(cfg, 'logging'):
            cfg.logging.use_wandb = 0

        # Apply the overrides; a different class count means the saved head no longer fits.
        keep_head = True
        for option, value in kwargs.items():
            if option == 'data.num_classes' and cfg.data.num_classes != value:
                keep_head = False
            with open_dict(cfg):
                OmegaConf.update(cfg, option, value)

        # A new base_fn without an explicit folder moves the logging folder along with it.
        base_fn_overridden = 'logging.base_fn' in kwargs
        if base_fn_overridden and 'logging.folder' not in kwargs:
            cfg.logging.folder = f"{SAVE_DIR}/logs/{cfg.logging.base_fn}"

        model = fovinet_cls(cfg, device=device)

        if load:
            prefix = 'module.'
            weights = {
                (name[len(prefix):] if name.startswith(prefix) else name): tensor
                for name, tensor in state_dict[model_key].items()
                if keep_head or 'head.' not in name
            }
            state_dict[model_key] = weights
            model.load_state_dict(weights, strict=load_strict)

    return model