cd

Primary symbols include classes and functions for visualization, configuration, timing, and general-purpose utilities.

Visualization

label_cmap(labels: ndarray, colors: str | ndarray = 'rand', zero_val: float | tuple | list = 0.0, rgba: bool = True, alpha: float | None = None)

Label colormap.

Applies a colormap to a label image.

Parameters:
  • labels – Label image. Typically Array[h, w].

  • colors – Either ‘rand’ or one of [‘Pastel1’, ‘Pastel2’, ‘Paired’, ‘Accent’, ‘Dark2’, ‘Set1’, ‘Set2’, ‘Set3’, ‘tab10’, ‘tab20’, ‘tab20b’, ‘tab20c’] (see matplotlib’s qualitative colormaps) or Array[n, c].

  • zero_val – Special color for the zero label (usually background).

  • rgba – Whether to add an alpha channel to rgb colors.

  • alpha – Specific alpha value.

Returns:

Mapped labels. E.g. rgba mapping Array[h, w] -> Array[h, w, 4].

random_colors_hsv(num, hue_range=(0, 180), saturation_range=(60, 133), value_range=(180, 256))
get_axes(fig=None) List[SubplotBase]

Get current pyplot axes.

Parameters:

fig – Optional Figure.

Returns:

List of Axes.

imshow(image: ndarray | Tensor, figsize=None, **kw)

Imshow.

PyPlot’s imshow function with benefits.

Parameters:
  • image – Image. Valid Formats: Array[h, w], Array[h, w, c] or Array[n, h, w, c], Tensor[h, w], Tensor[c, h, w] or Tensor[n, c, h, w]. Images without channels or just one channel are plotted as grayscale images by default.

  • figsize – Figure size. If specified, a new plt.figure(figsize=figsize) is created.

  • **kw – Imshow keyword arguments.

imshow_col(*images, titles=None, figsize=(3, 3), tight=True, **kwargs)

Imshow column.

Display a list of images in a column.

Parameters:
  • *images – Images.

  • titles – Titles. Either string or list of strings (one for each image).

  • figsize – Figure size per image.

  • tight – Whether to use tight layout.

  • **kwargs – Keyword arguments for cd.imshow.

imshow_grid(*images, titles=None, figsize=(3, 3), tight=True, **kwargs)

Imshow grid.

Display a list of images in an NxN grid.

Parameters:
  • *images – Images.

  • titles – Titles. Either string or list of strings (one for each image).

  • figsize – Figure size per image.

  • tight – Whether to use tight layout.

  • **kwargs – Keyword arguments for cd.imshow.
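One way to lay out an NxN grid is to pick the smallest N with N*N >= number of images and fill the slots row by row. The sketch below assumes that rule, and also assumes that the per-image `figsize` is scaled by the grid dimensions; `grid_shape`, `grid_positions` and `grid_figsize` are hypothetical helpers, not part of cd.

```python
import math
from typing import List, Tuple


def grid_shape(num_images: int) -> Tuple[int, int]:
    """Smallest N x N grid that holds num_images (assumed layout rule)."""
    n = math.isqrt(num_images)
    if n * n < num_images:
        n += 1  # round up when num_images is not a perfect square
    return n, n


def grid_positions(num_images: int) -> List[Tuple[int, int]]:
    # Row-major (row, column) slot for each image
    _, cols = grid_shape(num_images)
    return [divmod(i, cols) for i in range(num_images)]


def grid_figsize(num_images: int, figsize=(3, 3)) -> Tuple[float, float]:
    # Assumption: figsize is per image, so the total size scales with the grid
    rows, cols = grid_shape(num_images)
    return cols * figsize[0], rows * figsize[1]


print(grid_shape(5), grid_positions(5), grid_figsize(5))
# (3, 3) [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)] (9, 9)
```

With five images, a strict NxN layout leaves four empty slots; this is the trade-off of a square grid compared to a layout with fewer rows.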

imshow_row(*images, titles=None, figsize=(3, 3), tight=True, **kwargs)

Imshow row.

Display a list of images in a row.

Parameters:
  • *images – Images.

  • titles – Titles. Either string or list of strings (one for each image).

  • figsize – Figure size per image.

  • tight – Whether to use tight layout.

  • **kwargs – Keyword arguments for cd.imshow.

plot_box(x_min, y_min, x_max, y_max, linewidth=1, edgecolor='#4AF626', facecolor='none', text=None, **kwargs)
plot_boxes(boxes, texts: List[str] | None = None, **kwargs)
plot_contours(contours, contour_line_width=2, contour_linestyle='-', fill=0.2, color=None, texts: list | None = None, **kwargs)
plot_mask(mask, alpha=1)
plot_score(score: float, x, y, cls: int | str | None = None, cls_names: dict | None = None, **kwargs)
plot_text(text, x, y, color='black', stroke_width=5, stroke_color='w')
quiver_plot(vector_field, image=None, cmap='gray', figsize=None, qcmap='twilight', linewidth=0.125, width=0.19, alpha=0.7)

Quiver plot.

Plots a 2D vector field. Can be used to visualize the local refinement tensor.

Parameters:
  • vector_field – Array[2, w, h].

  • image – Array[h, w(, 3)].

  • cmap – Image color map.

  • figsize – Figure size.

  • qcmap – Quiver color map. Consider seaborn’s husl palette: qcmap = ListedColormap(sns.color_palette("husl", 8).as_hex())

  • linewidth – Quiver line width.

  • width – Quiver width.

  • alpha – Quiver alpha.

save_fig(filename, close=True)

Save Figure.

Save current Figure to disk.

Parameters:
  • filename – Filename, e.g. image.png.

  • close – Whether to close all unhandled Figures. Do not close them if you intend to call plt.show().

show_detection(image=None, contours=None, coordinates=None, boxes=None, scores=None, masks=None, figsize=None, label_stack=None, classes: List[str] | List[int] | None = None, class_names: dict | None = None, contour_line_width=2, contour_linestyle='-', fill=0.2, cmap=Ellipsis, **kwargs)

Config

class Config(**kwargs)

Config.

Just a dict with benefits.

Config objects treat values as attributes, print nicely, and can be saved to and loaded from JSON files. The hash method also offers a compact, practically unique string representation of the Config content.

Examples

>>> import celldetection as cd, torch.nn as nn
>>> conf = cd.Config(optimizer={'Adam': dict(lr=.001)}, epochs=100)
>>> conf
Config(
  (optimizer): {'Adam': {'lr': 0.001}}
  (epochs): 100
)
>>> conf.to_json('config.json')
>>> conf.hash()
'cf647b987ca37eb954d8bd01df01809e'
>>> conf.epochs = 200
>>> conf.epochs
200
>>> module = nn.Conv2d(1, 2, 3)
>>> optimizer = cd.conf2optimizer(conf.optimizer, module.parameters())
>>> optimizer
Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        eps: 1e-08
        lr: 0.001
        weight_decay: 0
)
Parameters:

**kwargs – Items.

args(fn: Callable)

Examples

>>> conf = cd.Config(a=1, b=2, c=42)
>>> def f(a, b):
...     return a + b
>>> f(*conf.args(f))
3
Parameters:

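fn – Function whose parameter names are looked up in the Config.

Returns:

Values of the Config items whose keys match parameter names of fn, for positional unpacking as in f(*conf.args(f)).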

extra_repr() str
static from_json(filename)
hash() str
kwargs(fn: Callable)

Examples

>>> conf = cd.Config(a=1, b=2, c=42)
>>> def f(a, b):
...     return a + b
>>> f(**conf.kwargs(f))
3
Parameters:

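fn – Function whose parameter names are looked up in the Config.

Returns:

Dictionary of the Config items whose keys match parameter names of fn, for unpacking as in f(**conf.kwargs(f)).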
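The args and kwargs examples above suggest that both methods select Config items by the parameter names of fn. A minimal sketch of that lookup with inspect.signature, using a plain dict in place of Config; `kwargs_for` and `args_for` are hypothetical names, not the library's implementation.

```python
import inspect
from typing import Any, Callable, Dict, List


def kwargs_for(conf: Dict[str, Any], fn: Callable) -> Dict[str, Any]:
    """Hypothetical: entries of conf whose keys name a parameter of fn."""
    params = inspect.signature(fn).parameters
    # Iterating params (not conf) keeps fn's parameter order
    return {name: conf[name] for name in params if name in conf}


def args_for(conf: Dict[str, Any], fn: Callable) -> List[Any]:
    # Positional variant; only lines up if every leading parameter of fn is in conf
    return list(kwargs_for(conf, fn).values())


def f(a, b):
    return a + b


conf = {'a': 1, 'b': 2, 'c': 42}
print(f(**kwargs_for(conf, f)), f(*args_for(conf, f)))  # 3 3
```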

load(filename)
to_dict() dict
to_json(filename)
to_txt(filename, mode='w', **kwargs)
class Schedule(**kwargs)

Schedule.

Provides an easy interface to the cross product of different configurations.

Examples

>>> s = cd.Schedule(
...     lr=(0.001, 0.0005),
...     net=('resnet34', 'resnet50'),
...     epochs=100
... )
>>> len(s)
4
>>> s[:]
[Config(
  (epochs): 100
  (lr): 0.001
  (net): 'resnet34'
), Config(
  (epochs): 100
  (lr): 0.001
  (net): 'resnet50'
), Config(
  (epochs): 100
  (lr): 0.0005
  (net): 'resnet34'
), Config(
  (epochs): 100
  (lr): 0.0005
  (net): 'resnet50'
)]
>>> for config in s:
...     print(config.lr, config.net, config.epochs)
0.001 resnet34 100
0.001 resnet50 100
0.0005 resnet34 100
0.0005 resnet50 100
Parameters:

**kwargs – Configurations. Possible item layouts: <name>: <static setting>, <name>: (<option1>, ..., <optionN>), <name>: [<option1>, ..., <optionN>], <name>: {<option1>, ..., <optionN>}.
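The cross product described above can be sketched with itertools.product: tuples, lists and sets become option lists, and any other value becomes a single static setting. `schedule_product` is a hypothetical helper, not Schedule itself; sorting the keys is an assumption that happens to reproduce the order of the s[:] output above, and conditional settings from add are not modeled.

```python
import itertools
from typing import Any, Dict, List


def schedule_product(**kwargs) -> List[Dict[str, Any]]:
    """Hypothetical: expand option collections into all combinations."""
    names = sorted(kwargs)  # assumption: keys sorted, as in the s[:] output above
    options = []
    for name in names:
        value = kwargs[name]
        # Tuples, lists and sets hold options; anything else is a static setting
        options.append(list(value) if isinstance(value, (tuple, list, set)) else [value])
    return [dict(zip(names, combo)) for combo in itertools.product(*options)]


configs = schedule_product(lr=(0.001, 0.0005), net=('resnet34', 'resnet50'), epochs=100)
for c in configs:
    print(c['lr'], c['net'], c['epochs'])
```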

add(d: dict | None = None, conditions: dict | None = None, **kwargs)

Add setting to schedule.

Examples

>>> schedule = cd.Schedule(model=('resnet18', 'resnet50'), batch_size=8)
>>> schedule.add(batch_size=(16, 32), conditions={'model': 'resnet18'})
>>> schedule[:]
[Config(
   (batch_size): 16,
   (model): resnet18,
 ),
 Config(
   (batch_size): 32,
   (model): resnet18,
 ),
 Config(
   (batch_size): 8,
   (model): resnet50,
 )]
>>> schedule = cd.Schedule(model=('resnet18', 'resnet50'))
>>> schedule.add(batch_size=(16, 32), conditions={'model': 'resnet18'})
>>> schedule[:]
[Config(
   (model): resnet18,
   (batch_size): 16,
 ),
 Config(
   (model): resnet18,
   (batch_size): 32,
 ),
 Config(
   (model): resnet50,
 )]
>>> schedule = cd.Schedule(model=('resnet18', 'resnet50'), batch_size=(64, 128, 256))
>>> schedule.add(batch_size=(16, 32), conditions={'model': 'resnet50'})
>>> schedule[:]
[Config(
   (batch_size): 64
   (model): 'resnet18'
 ),
 Config(
   (batch_size): 16
   (model): 'resnet50'
 ),
 Config(
   (batch_size): 32
   (model): 'resnet50'
 ),
 Config(
   (batch_size): 128
   (model): 'resnet18'
 ),
 Config(
   (batch_size): 256
   (model): 'resnet18'
 )]
Parameters:
  • d – Dictionary of settings.

  • conditions – If set, added settings are only applied if conditions are met. Note: Conditioned settings replace/override existing settings if conditions are met.

  • **kwargs – Configurations. Possible item layouts: <name>: <static setting>, <name>: (<option1>, …, <optionN>), <name>: [<option1>, …, <optionN>], <name>: {<option1>, …, <optionN>}.

property configs
static from_json(filename)
get_multiples(num=2)
load(filename)
property product
to_dict()
to_dict_list()
to_json(filename)
conf2augmentation(settings: dict) Compose

Config to augmentation.

Maps settings to a composed augmentation workflow using albumentations.

Examples

>>> import celldetection as cd
>>> cd.conf2augmentation({
...     'RandomRotate90': dict(p=.5),
...     'Transpose': dict(p=.5),
... })
Compose([
  RandomRotate90(always_apply=False, p=0.5),
  Transpose(always_apply=False, p=0.5),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={})
Parameters:

settings – Settings dictionary as {name: kwargs}.

Returns:

A.Compose object.

conf2call(settings: dict | str, origin, **kwargs)

Config to call.

Examples

>>> import celldetection as cd
>>> model = cd.conf2call('ResNet18', cd.models, in_channels=1)
>>> model = cd.conf2call({'ResNet18': dict(in_channels=1)}, cd.models)
Parameters:
  • settings – Name or dictionary as {name: kwargs}. Name must be the symbol’s name that is to be retrieved from origin.

  • origin – Object from which the symbol is retrieved by name, e.g. cd.models.

  • **kwargs – Additional keyword arguments for the call of retrieved symbol.

Returns:

Return value of the call of retrieved symbol.
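conf2call resolves a symbol by name on origin and calls it. The sketch below reproduces that idea with getattr, using the standard library's fractions module as origin so it runs without celldetection; `conf2call_sketch` is hypothetical, and letting explicit keyword arguments override those in settings is an assumption.

```python
import fractions
from typing import Any, Union


def conf2call_sketch(settings: Union[dict, str], origin: Any, **kwargs) -> Any:
    """Hypothetical: look up the settings' name on origin and call it."""
    if isinstance(settings, str):
        name, call_kwargs = settings, {}
    else:
        # Expect a single item of the form {name: kwargs}
        (name, call_kwargs), = settings.items()
    symbol = getattr(origin, name)
    # Assumption: explicit keyword arguments take precedence over settings
    return symbol(**{**call_kwargs, **kwargs})


third = conf2call_sketch({'Fraction': dict(numerator=1, denominator=3)}, fractions)
half = conf2call_sketch('Fraction', fractions, numerator=2, denominator=4)
print(third, half)  # 1/3 1/2
```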

conf2optimizer(settings: dict, params)

Config to optimizer.

Examples

>>> import celldetection as cd, torch.nn as nn
>>> module = nn.Conv2d(1, 2, 3)
>>> optimizer = cd.conf2optimizer({'Adam': dict(lr=.0002, betas=(0.5, 0.999))}, module.parameters())
>>> optimizer
Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.5, 0.999)
        eps: 1e-08
        lr: 0.0002
        weight_decay: 0
)
Parameters:
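  • settings – Settings dictionary as {name: kwargs}, e.g. {'Adam': dict(lr=.0002)}.

  • params – Parameters to optimize, e.g. module.parameters().

Returns:

Optimizer.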

conf2scheduler(settings: dict, optimizer, origins=None)
conf2tweaks_(settings: dict, module: Module)

Config to tweaks.

Apply tweaks to module.

Notes

  • If module does not contain specified objects, nothing happens.

Examples

>>> import celldetection as cd, torch.nn as nn
>>> model = cd.models.ResNet18(in_channels=3)
>>> cd.conf2tweaks_({nn.BatchNorm2d: dict(momentum=0.05)}, model)  # sets momentum to 0.05
>>> cd.conf2tweaks_({'BatchNorm2d': dict(momentum=0.42)}, model)  # sets momentum to 0.42
>>> cd.conf2tweaks_({'LeakyReLU': dict(negative_slope=0.2)}, model)  # sets negative_slope to 0.2
Parameters:
  • settings – Settings dictionary as {name: kwargs}.

  • module – Module that is to be tweaked.

Timing

print_timing(name, seconds)
start_timer(name, cuda=True, collect=True)

Keyword PyTorch timer.

Can be used to measure PyTorch GPU times.

Parameters:

name – Keyword that identifies the timer.

stop_timer(name, cuda=True, verbose=True)

Util

class Bytes

Bytes.

Printable integer that represents Bytes.

UNITS = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB', 'BiB']
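A printable byte count can be sketched as an int subclass whose string form divides by 1024 until the value fits one of the UNITS listed above. `BytesSketch` is hypothetical; the real Bytes class may format differently (e.g. number of decimals or spacing).

```python
class BytesSketch(int):
    """Hypothetical printable integer of bytes; behaves like int in arithmetic."""

    UNITS = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB', 'BiB']

    def __str__(self) -> str:
        value, index = float(self), 0
        # Binary prefixes: each unit is 1024 times the previous one
        while abs(value) >= 1024 and index < len(self.UNITS) - 1:
            value /= 1024
            index += 1
        return f"{value:.2f} {self.UNITS[index]}"


print(BytesSketch(1536))  # 1.50 KiB
```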
class Dict(**kwargs)

Dictionary.

Just a dict that treats values like attributes.

Examples

>>> import celldetection as cd
>>> d = cd.Dict(my_value=42)
>>> d.my_value
42
>>> d.my_value += 1
>>> d.my_value
43
Parameters:

**kwargs

class GpuStats(delimiter=', ')

GPU Statistics.

Simple interface to print live GPU statistics from pynvml.

Examples

>>> import celldetection as cd
>>> stat = cd.GpuStats()  # initialize once
>>> print(stat)  # print current statistics
gpu0(free: 22.55GB, used: 21.94GB, util: 93%), gpu1(free: 1.03GB, used: 43.46GB, util: 98%)
Parameters:

delimiter – Delimiter used for printing.

dict(byte_lvl=3, prefix='gpu')
class NormProxy(norm, **kwargs)

Norm Proxy.

Examples

>>> import torch.nn as nn
>>> from celldetection import NormProxy
>>> GroupNorm = NormProxy('groupnorm', num_groups=32)
>>> GroupNorm(64)
GroupNorm(32, 64, eps=1e-05, affine=True)
>>> GroupNorm = NormProxy(nn.GroupNorm, num_groups=32)
>>> GroupNorm(64)
GroupNorm(32, 64, eps=1e-05, affine=True)
>>> BatchNorm2d = NormProxy('batchnorm2d', momentum=.2)
>>> BatchNorm2d(3)
BatchNorm2d(3, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
>>> BatchNorm2d = NormProxy(nn.BatchNorm2d, momentum=.2)
>>> BatchNorm2d(3)
BatchNorm2d(3, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
Parameters:
  • norm – Norm class or name.

  • **kwargs – Keyword arguments.

class Percent(x=0, /)

Percent.

Printable float that represents a percentage.

class Tiling(tile_size: tuple, context_shape: tuple, overlap=0)
add_to_loss_dict(d: dict, key: str, loss: Tensor, weight=None)
append_hash_to_filename(filename, num=None, ext=True)
asnumpy(v)

As numpy.

Converts all Tensors to numpy arrays.

Notes

  • Works recursively.

  • The following input items are not altered: Numpy array, int, float, bool, str

Parameters:

v – Tensor or list/tuple/dict of Tensors.

Returns:

Input with Tensors converted to numpy arrays.

base64_to_image(code, as_numpy=True)
copy_script(dst, no_script_okay=True, frame=None, verbose=False)

Copy current script.

Copies the script from where this function is called to dst. By default, nothing happens if this function is not called from within a script.

Parameters:
  • dst – Copy destination. Filename or folder.

  • no_script_okay – If False, raise a FileNotFoundError when no script is found.

  • frame – Context frame.

  • verbose – Whether to print source and destination when copying.

count_submodules(module: Module, class_or_tuple) int

Count submodules.

Count the number of submodules of the specified type(s).

Examples

>>> count_submodules(cd.models.U22(1, 0), nn.Conv2d)
22
Parameters:
  • module – Module.

  • class_or_tuple – All instances of given class_or_tuple are to be counted.

Returns:

Number of submodules.

dict2model(conf, **kwargs)
dict_hash(dictionary: Dict[str, Any]) str

MD5 hash of a dictionary.

References

https://www.doc.ic.ac.uk/~nuric/coding/how-to-hash-a-dictionary-in-python.html

Parameters:

dictionary – A dictionary.

Returns:

MD5 hash of the dictionary as a string.
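The referenced article hashes a dictionary by serializing it to JSON with sorted keys and applying MD5. A minimal sketch of that approach; `dict_hash_sketch` is a hypothetical name, and cd.dict_hash may differ in details such as JSON options.

```python
import hashlib
import json
from typing import Any, Dict


def dict_hash_sketch(dictionary: Dict[str, Any]) -> str:
    """MD5 of a JSON serialization with sorted keys (values must be JSON-serializable)."""
    # sort_keys makes the hash independent of insertion order
    encoded = json.dumps(dictionary, sort_keys=True).encode()
    return hashlib.md5(encoded).hexdigest()


print(dict_hash_sketch({'epochs': 100, 'lr': 0.001}))
```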

ensure_num_tuple(v, num=2, msg='')
exponential_moving_average_(module_avg, module, alpha=0.999, alpha_non_trainable=0.0, buffers=True)

Exponential moving average.

Update the variables of module_avg to be slightly closer to module.

Notes

  • Whether a parameter is trainable or not is checked on module.

  • module_avg can be on a different device and entirely frozen.

Parameters:
  • module_avg – Average module. The parameters of this model are to be updated.

  • module – Other Module.

  • alpha – Weight of the current values of module_avg for trainable parameters; the values of module are weighted by (1 - alpha).

  • alpha_non_trainable – Same as alpha, but for non-trainable parameters.

  • buffers – Whether to copy buffers from module to module_avg.
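Per parameter, the update described above is p_avg ← alpha · p_avg + (1 - alpha) · p, with alpha_non_trainable in place of alpha for non-trainable parameters; the default alpha_non_trainable=0.0 therefore copies those values outright. The sketch below applies this rule to plain floats keyed by name instead of Module parameters, purely for illustration; `ema_update_` and the `trainable` mapping are hypothetical.

```python
from typing import Dict


def ema_update_(avg: Dict[str, float], current: Dict[str, float], trainable: Dict[str, bool],
                alpha: float = 0.999, alpha_non_trainable: float = 0.0) -> None:
    """Hypothetical in-place EMA over named scalar 'parameters'."""
    for name, value in current.items():
        # Trainability is looked up on the current model, not on the average
        a = alpha if trainable[name] else alpha_non_trainable
        avg[name] = a * avg[name] + (1 - a) * value


avg = {'weight': 0.0, 'frozen': 5.0}
current = {'weight': 1.0, 'frozen': 7.0}
ema_update_(avg, current, trainable={'weight': True, 'frozen': False}, alpha=0.9)
print(avg)  # 'weight' moves 10% toward 1.0; 'frozen' is copied since alpha_non_trainable=0
```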

fetch_image(url, numpy=True)

Fetch image from URL.

Download an image from URL and convert it to a numpy array or PIL Image.

Parameters:
  • url – URL

  • numpy – Whether to convert PIL Image to numpy array.

Returns:

PIL Image or numpy array.

fetch_model(name, map_location=None, **kwargs)

Fetch model from URL.

Loads model or state dict from URL.

Parameters:
  • name – Model name hosted on celldetection.org or URL. URLs must start with ‘http’.

  • map_location – A function, torch.device, string or a dict specifying how to remap storage locations.

  • **kwargs – Keyword arguments as documented for torch.hub.load_state_dict_from_url.

freeze_(module: Module, recurse=True)

Freeze.

Freezes a module by setting param.requires_grad=False and calling module.eval().

Parameters:
  • module – Module.

  • recurse – Whether to freeze parameters of this layer and submodules or only parameters that are direct members of this module.

freeze_submodules_(module: Module, *names, recurse=True)

Freeze specific submodules.

Freezes submodules by setting param.requires_grad=False and calling submodule.eval().

Parameters:
  • module – Module.

  • names – Names of submodules.

  • recurse – Whether to freeze parameters of specified modules and their respective submodules or only parameters that are direct members of the specified submodules.

from_h5(filename, *keys, **keys_slices)

From h5.

Reads data from hdf5 file.

Parameters:
  • filename – Filename.

  • *keys – Keys to read.

  • **keys_slices – Keys with indices or slices. E.g. from_h5('file.h5', 'key0', key=slice(0, 42)).

Returns:

Data from hdf5 file. As tuple if multiple keys are provided.

from_json(filename)

From JSON.

Load object from JSON file with name filename.

Parameters:

filename – File name.

frozen_params(module: Module, recurse=True) Iterator[Parameter]

Frozen parameters.

Retrieve all frozen parameters.

Parameters:
  • module – Module.

  • recurse – Whether to also include parameters of all submodules.

Returns:

Module parameters.

gaussian_kernel(kernel_size, sigma=-1, nd=2) ndarray

Get Gaussian kernel.

Constructs and returns a Gaussian kernel.

Parameters:
  • kernel_size – Kernel size as int or tuple. It should be odd and positive.

  • sigma – Gaussian standard deviation as float or tuple. If it is non-positive, it is computed from kernel_size as sigma = 0.3*((kernel_size-1)*0.5 - 1) + 0.8.

  • nd – Number of kernel dimensions.

Returns:

Gaussian Kernel.
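The following plain-Python sketch builds a kernel from the Gaussian density, using the sigma fallback formula given above when sigma is non-positive, and forms the 2D kernel as the outer product of two 1D kernels (the Gaussian is separable). The function names are hypothetical, only integer kernel sizes and nd ≤ 2 are covered, and normalizing the weights to sum to 1 is an assumption about gaussian_kernel.

```python
import math
from typing import List


def gaussian_kernel_1d(kernel_size: int, sigma: float = -1) -> List[float]:
    """Hypothetical 1D Gaussian kernel, normalized to sum to 1."""
    if sigma <= 0:
        # Fallback from the docs above: derive sigma from the kernel size
        sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
    center = (kernel_size - 1) / 2
    weights = [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(kernel_size)]
    total = sum(weights)
    return [w / total for w in weights]


def gaussian_kernel_2d(kernel_size: int, sigma: float = -1) -> List[List[float]]:
    # The Gaussian is separable, so the 2D kernel is an outer product of 1D kernels
    k = gaussian_kernel_1d(kernel_size, sigma)
    return [[a * b for b in k] for a in k]


print([round(w, 3) for w in gaussian_kernel_1d(3)])
```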

get_device(module: Module | Tensor | device)

Get device.

Get device from Module.

Parameters:

module – Module. If module is a string or torch.device already, it is returned as is.

Returns:

Device.

get_nd_batchnorm(dim: int)
get_nd_conv(dim: int)
get_nd_dropout(dim: int)
get_nd_linear(dim: int)
get_nd_max_pool(dim: int)
get_nn(item: str | Module | Type[Module], src=None, nd=None)
get_tiling_slices(size: Sequence[int], crop_size: int | Sequence[int], strides: int | Sequence[int], return_overlaps=False) Tuple[Iterable[slice], Tuple[int]] | Tuple[Iterable[slice], Iterable[Tuple[int]], Tuple[int]]

Get tiling slices.

Parameters:
  • size – Reference size as tuple.

  • crop_size – Crop size.

  • strides – Strides.

  • return_overlaps – Whether to return overlaps.

Returns:

If return_overlaps is False: iterator of tiling slices (each slice defining a tile) and number of tiles per dimension as tuple.

If return_overlaps is True: iterator of tiling slices (each slice defining a tile), iterator of overlaps (overlaps with adjacent tiles for each tile) and number of tiles per dimension as tuple.

Return type:

Tuple[Iterable[slice], Tuple[int]] or Tuple[Iterable[slice], Iterable[Tuple[int]], Tuple[int]]
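One plausible tiling scheme, sketched under stated assumptions: tiles start every stride pixels, the last tile in each dimension is shifted back so it ends exactly at the border, and the full tile grid is the product of the per-dimension slices. `tiling_slices_sketch` is hypothetical, does not compute overlaps, and may not match get_tiling_slices exactly.

```python
import itertools
import math
from typing import List, Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def _per_dim(v: IntOrSeq, ndim: int) -> Tuple[int, ...]:
    # Broadcast a scalar to one value per dimension
    return (v,) * ndim if isinstance(v, int) else tuple(v)


def tiling_slices_sketch(size: Sequence[int], crop_size: IntOrSeq,
                         strides: IntOrSeq) -> Tuple[List[Tuple[slice, ...]], Tuple[int, ...]]:
    """Hypothetical tiling: regular strides, last tile shifted back to the border."""
    crops, steps = _per_dim(crop_size, len(size)), _per_dim(strides, len(size))
    per_dim: List[List[slice]] = []
    for s, c, st in zip(size, crops, steps):
        c = min(c, s)  # a tile never exceeds the image
        n = math.ceil((s - c) / st) + 1  # tiles needed to reach the far border
        starts = [min(i * st, s - c) for i in range(n)]
        per_dim.append([slice(a, a + c) for a in starts])
    tiles = list(itertools.product(*per_dim))  # one tuple of slices per tile
    return tiles, tuple(len(d) for d in per_dim)


tiles, num_tiles = tiling_slices_sketch((10, 8), crop_size=4, strides=(3, 4))
print(num_tiles, tiles[-1])  # (3, 2) (slice(6, 10, None), slice(4, 8, None))
```

Each tile is a tuple of slices, so it can index an array directly, e.g. image[tile].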

get_warmup_factor(step, steps=1000, factor=0.001, method='linear')
hash_file(filename)
image_to_base64(img: ndarray, ext='png', as_url=True, url_template=None)
inject_extra_repr_(module, name, fn)

Inject extra representation.

Injects additional extra_repr function to module. This can be helpful to indicate presence of hooks.

Note

This is an inplace operation.

Notes

  • This op may impair pickling.

Parameters:
  • module – Module.

  • name – Name of the injected function (only used to avoid duplicate injection).

  • fn – Callback function.

iter_submodules(module: Module, class_or_tuple, recursive=True)
load_image(name, method='imageio') ndarray

Load image.

Load image from URL or from filename via imageio or pytiff.

Parameters:
  • name – URL (must start with http) or filename.

  • method – Method to use for filenames.

Returns:

Image.

load_model(filename, map_location=None, **kwargs)
lookup_nn(item: str, *a, src=None, call=True, inplace=True, nd=None, **kw)

Examples

>>> lookup_nn('batchnorm2d', 32)
    BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> lookup_nn(torch.nn.BatchNorm2d, 32)
    BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> lookup_nn('batchnorm2d', num_features=32)
    BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> lookup_nn('tanh')
    Tanh()
>>> lookup_nn('tanh', call=False)
    torch.nn.modules.activation.Tanh
>>> lookup_nn('relu')
    ReLU(inplace=True)
>>> lookup_nn('relu', inplace=False)
    ReLU()
>>> # Dict notation to contain all keyword arguments for calling in `item`. Always called once.
... lookup_nn(dict(relu=dict(inplace=True)), call=False)
    ReLU(inplace=True)
>>> lookup_nn({'NormProxy': {'norm': 'GroupNorm', 'num_groups': 32}}, call=False)
    NormProxy(GroupNorm, kwargs={'num_groups': 32})
>>> lookup_nn({'NormProxy': {'norm': 'GroupNorm', 'num_groups': 32}}, 32, call=True)
    GroupNorm(32, 32, eps=1e-05, affine=True)
Parameters:
  • item – Lookup item. None is equivalent to identity.

  • *a – Arguments passed to item if called.

  • src – Lookup source.

  • call – Whether to call item.

  • inplace – Default setting for items that take an inplace argument when called. As the default is True, lookup_nn('relu') returns a ReLU instance with inplace=True.

  • nd – If set, replace dimension statement (e.g. ‘2d’ in nn.Conv2d) with nd.

  • **kw – Keyword arguments passed to item when it is called.

Returns:

Looked up item.

model2dict(model: Module)
num_bytes(x: ndarray | Tensor)

Num Bytes.

Returns the size in bytes of the given ndarray or Tensor.

Parameters:

x – Array or Tensor.

Returns:

Bytes

num_params(module: Module, trainable=None, recurse=True) int

Number of parameters.

Count the number of parameters.

Parameters:
  • module – Module

  • trainable – Optionally filter for trainable or frozen parameters.

  • recurse – Whether to also include parameters of all submodules.

Returns:

Number of parameters.

print_to_file(*args, filename, mode='w', **kwargs)
random_code_name(chars=4) str

Random code name.

Generates random code names that are somewhat pronounceable and memorable.

Examples

>>> import celldetection as cd
>>> cd.random_code_name()
kolo
>>> cd.random_code_name(6)
lotexo
Parameters:

chars – Number of characters.

Returns:

String.
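The examples 'kolo' and 'lotexo' alternate consonants and vowels, which is one simple way to make names pronounceable. A hypothetical sketch of that pattern; the letter sets and the `rng` parameter are assumptions, not cd's implementation.

```python
import random
from typing import Optional

CONSONANTS = "bdfgklmnprstvxz"
VOWELS = "aeiou"


def random_code_name_sketch(chars: int = 4, rng: Optional[random.Random] = None) -> str:
    """Hypothetical: alternate consonant/vowel, e.g. 'kolo' or 'lotexo'."""
    rng = rng if rng is not None else random.Random()
    # Even positions get a consonant, odd positions a vowel
    return "".join(rng.choice(CONSONANTS if i % 2 == 0 else VOWELS) for i in range(chars))


print(random_code_name_sketch(), random_code_name_sketch(6))
```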

random_code_name_dir(directory='./out', chars=6, comm=None, root_rank=0)

Random code name directory.

Creates a random code name and a subdirectory with that name under directory. Code names that are already taken (subdirectory already exists) are not reused.

Parameters:
  • directory – Root directory.

  • chars – Number of characters for the code name.

  • comm – MPI Comm. If provided, the code name and directory are automatically broadcast to all ranks of comm.

  • root_rank – Root rank. Only the root rank creates code name and directory.

Returns:

Tuple of code name and created directory.

random_seed(seed, backends=False, deterministic_torch=True)

Set random seed.

Sets the random seed for random, np.random and torch.manual_seed, optionally configures torch.backends.cudnn (see backends), and advises PyTorch to use deterministic algorithms (see deterministic_torch).

References

https://pytorch.org/docs/stable/notes/randomness.html

Parameters:
  • seed – Random seed.

  • backends – Whether to also adapt backends. If True, cuDNN’s benchmark feature is disabled. This causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced performance. The selected algorithm is also set to run deterministically.

  • deterministic_torch – Whether to set PyTorch operations to behave deterministically.

reduce_loss_dict(losses: dict, divisor)
replace_module_(module: Module, class_or_tuple, substitute: Type[Module] | Module, recursive=True, inherit_attr: list | str | dict | None = None, **kwargs)

Replace module.

Replace all occurrences of class_or_tuple in module with substitute.

Examples

>>> # Replace all ReLU activations with LeakyReLU
... cd.replace_module_(network, nn.ReLU, nn.LeakyReLU)
>>> # Replace all BatchNorm layers with InstanceNorm and inherit `num_features` attribute
... cd.replace_module_(network, nn.BatchNorm2d, nn.InstanceNorm2d, inherit_attr=['num_features'])
>>> # Replace all BatchNorm layers with GroupNorm and inherit `num_features` attribute
... cd.replace_module_(network, nn.BatchNorm2d, nn.GroupNorm, num_groups=32,
...                    inherit_attr={'num_channels': 'num_features'})
Parameters:
  • module – Module.

  • class_or_tuple – Class or tuple of classes that are to be replaced.

  • substitute – Substitute class or object.

  • recursive – Whether to replace modules recursively.

  • inherit_attr – Attributes to be inherited. String, list or dict of attribute names. Attribute values are retrieved from replaced module and passed to substitute constructor. Formats: 'attr_name', ['attr_name0', 'attr_name1', ...], {'substitute_kw0': 'attr_name0', ...}

  • **kwargs – Keyword arguments passed to substitute constructor if it is a class.
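The three inherit_attr formats can all be normalized to a {substitute_kw: attr_name} mapping before the substitute is constructed. The helpers below are hypothetical and only sketch that step; the dict format mirrors the GroupNorm example above.

```python
from typing import Any, Dict, List, Optional, Union

InheritAttr = Optional[Union[str, List[str], Dict[str, str]]]


def normalize_inherit_attr(inherit_attr: InheritAttr) -> Dict[str, str]:
    """Hypothetical: map substitute keyword -> attribute name on the replaced module."""
    if inherit_attr is None:
        return {}
    if isinstance(inherit_attr, str):  # 'attr_name'
        return {inherit_attr: inherit_attr}
    if isinstance(inherit_attr, dict):  # {'substitute_kw': 'attr_name'}
        return dict(inherit_attr)
    return {name: name for name in inherit_attr}  # ['attr_name0', ...]


def inherited_kwargs(old_module: Any, inherit_attr: InheritAttr) -> Dict[str, Any]:
    # Read each attribute from the module that is being replaced
    return {kw: getattr(old_module, attr)
            for kw, attr in normalize_inherit_attr(inherit_attr).items()}
```

For the GroupNorm example, the result would be combined with the extra keyword arguments, roughly as nn.GroupNorm(num_groups=32, **inherited_kwargs(old_bn, {'num_channels': 'num_features'})), where old_bn is the BatchNorm2d being replaced.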

save_fetchable_model(model: Module, filename, append_hash=16)
spectral_norm_(module, class_or_tuple=<class 'torch.nn.modules.conv.Conv2d'>, recursive=True, name='weight', add_repr=False, **kwargs)

Spectral normalization.

Applies spectral normalization to parameters of all occurrences of class_or_tuple in the given module.

Note

This is an inplace operation.

Parameters:
  • module – Module.

  • class_or_tuple – Class or tuple of classes whose parameters are to be normalized.

  • recursive – Whether to search for modules recursively.

  • name – Name of weight parameter.

  • add_repr – Whether to indicate use of spectral norm in a module’s representation. Note that this may impair pickling.

  • **kwargs – Additional keyword arguments for torch.nn.utils.spectral_norm.

tensor_to(inputs: list | tuple | dict | Tensor, *args, **kwargs)

Tensor to device/dtype/other.

Recursively calls tensor.to(*args, **kwargs) for all Tensors in inputs.

Notes

  • Works recursively.

  • Non-Tensor items are not altered.

Parameters:
  • inputs – Tensor, list, tuple or dict. Non-Tensor objects are ignored. Tensors are substituted by result of tensor.to(*args, **kwargs) call.

  • *args – Arguments. See docstring of torch.Tensor.to.

  • **kwargs – Keyword arguments. See docstring of torch.Tensor.to.

Returns:

Inputs with Tensors replaced by tensor.to(*args, **kwargs).
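tensor_to, to_device and asnumpy share the same recursive pattern: apply an operation to every Tensor while rebuilding lists, tuples and dicts and passing other items through. A framework-agnostic sketch of that pattern; `apply_recursively` is a hypothetical helper, and with PyTorch one would pass fn=lambda t: t.to(device) and is_leaf=lambda x: isinstance(x, torch.Tensor).

```python
from typing import Any, Callable


def apply_recursively(inputs: Any, fn: Callable[[Any], Any],
                      is_leaf: Callable[[Any], bool]) -> Any:
    """Hypothetical: apply fn to every leaf inside nested lists, tuples and dicts."""
    if is_leaf(inputs):
        return fn(inputs)
    if isinstance(inputs, dict):
        return {k: apply_recursively(v, fn, is_leaf) for k, v in inputs.items()}
    if isinstance(inputs, (list, tuple)):
        # Rebuild the container with the same type (plain list/tuple assumed)
        return type(inputs)(apply_recursively(v, fn, is_leaf) for v in inputs)
    return inputs  # any other item is passed through unchanged


batch = {'inputs': [1.0, 2.0], 'meta': ('id', 3.0)}
print(apply_recursively(batch, lambda v: v * 2, lambda v: isinstance(v, float)))
# {'inputs': [2.0, 4.0], 'meta': ('id', 6.0)}
```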

to_device(batch: list | tuple | dict | Tensor, device)

To device.

Move Tensors to device. Input can be Tensor, tuple of Tensors, list of Tensors or a dictionary of Tensors.

Notes

  • Works recursively.

  • Non-Tensor items are not altered.

Parameters:
  • batch – Tensor, list, tuple or dict. Non-Tensor objects are ignored. Tensors are moved to device.

  • device – Device.

Returns:

Input with Tensors moved to device.

to_h5(filename, mode='w', chunks=None, compression=None, overwrite=False, create_dataset_kw: dict | None = None, **kwargs)

To hdf5 file.

Write data to hdf5 file.

Parameters:
  • filename – File name.

  • mode – File mode.

  • chunks – Chunks setting for created datasets. Chunk shape, or True to enable auto-chunking.

  • compression – Compression setting for created datasets. Legal values are 'gzip', 'szip', 'lzf'. If an integer in range(10), this indicates the gzip compression level. Otherwise, an integer indicates the number of a dynamically loaded compression filter.

  • overwrite – Whether to overwrite existing dataset.

  • create_dataset_kw – Additional keyword arguments for h5py.File().create_dataset.

  • **kwargs – Data as {dataset_name: data}.

to_json(filename, obj, mode='w')

To JSON.

Dump obj to JSON file with name filename.

Parameters:
  • filename – File name.

  • obj – Object.

  • mode – File mode.

to_tiff(filename, image, mode='w', method='tile', bigtiff=True)

To tiff file.

Write image to tiff file using pytiff. By default, the tiff is tiled, so that crops can be read from disk without loading the entire image into memory first.

Notes

  • pytiff must be installed to use this function.

References

https://pytiff.readthedocs.io/en/master/quickstart.html

Parameters:
  • filename – File name.

  • image – Image.

  • mode – Mode.

  • method – Method. Either 'tile' or 'scanline'.

  • bigtiff – Whether to use bigtiff format.

train_epoch(model, train_loader, device, optimizer, desc=None, scaler=None, scheduler=None, gpu_stats=False, progress=True)

Basic train function.

Notes

  • Model should return a dictionary: {'loss': Tensor[], ...}

  • Batch from train_loader should be a dictionary: {'inputs': Tensor[...], ...}

  • Model must be callable: model(batch['inputs'], targets=batch)

Parameters:
  • model – Model.

  • train_loader – Data loader.

  • device – Device.

  • optimizer – Optimizer.

  • desc – Description, appears in progress print.

  • scaler – Gradient scaler. If set, PyTorch’s autocast feature is used.

  • scheduler – Scheduler. Its step is called after each epoch.

  • gpu_stats – Whether to print GPU stats.

  • progress – Whether to show progress.

trainable_params(module: Module, recurse=True) Iterator[Parameter]

Trainable parameters.

Retrieve all trainable parameters.

Parameters:
  • module – Module.

  • recurse – Whether to also include parameters of all submodules.

Returns:

Module parameters.

tweak_module_(module: Module, class_or_tuple, must_exist=True, recursive=True, **kwargs)

Tweak module.

Set attributes for all modules that are instances of given class_or_tuple.

Examples

>>> import celldetection as cd, torch.nn as nn
>>> model = cd.models.ResNet18(in_channels=3)
>>> cd.tweak_module_(model, nn.BatchNorm2d, momentum=0.05)  # sets momentum to 0.05

Notes

This is an in-place operation.

Parameters:
  • module – PyTorch Module.

  • class_or_tuple – All instances of given class_or_tuple are to be tweaked.

  • must_exist – If True, an AttributeError is raised if a given keyword is not an existing attribute.

  • recursive – Whether to search for modules recursively.

  • **kwargs – Attributes to be tweaked: <attribute_name>=<value>.

unfreeze_(module: Module, recurse=True)

Unfreeze.

Unfreezes a module by setting param.requires_grad=True and calling module.train().

Parameters:
  • module – Module.

  • recurse – Whether to unfreeze parameters of this layer and submodules or only parameters that are direct members of this module.

unfreeze_submodules_(module: Module, *names, recurse=True)

Unfreeze specific submodules.

Unfreezes submodules by setting param.requires_grad=True and calling submodule.train().

Parameters:
  • module – Module.

  • names – Names of submodules.

  • recurse – Whether to unfreeze parameters of specified modules and their respective submodules or only parameters that are direct members of the specified submodules.

update_dict_(dst, src, override=False, keys: List[str] | Tuple[str] | None = None)
weight_norm_(module, class_or_tuple=<class 'torch.nn.modules.conv.Conv2d'>, recursive=True, name='weight', add_repr=False, **kwargs)

Weight normalization.

Applies weight normalization to parameters of all occurrences of class_or_tuple in the given module.

Note

This is an inplace operation.

Parameters:
  • module – Module.

  • class_or_tuple – Class or tuple of classes whose parameters are to be normalized.

  • recursive – Whether to search for modules recursively.

  • name – Name of weight parameter.

  • add_repr – Whether to indicate use of weight norm in a module’s representation. Note that this may impair pickling.

  • **kwargs – Additional keyword arguments for torch.nn.utils.weight_norm.

wrap_module_(module: Module, class_or_tuple, wrapper, recursive=True, **kwargs)