cd
Primary symbols include classes and functions for visualization, configuration, timing and generally useful utilities.
Visualization
- label_cmap(labels: ndarray, colors: str | ndarray = 'rand', zero_val: float | tuple | list = 0.0, rgba: bool = True, alpha: float | None = None)
Label colormap.
Applies a colormap to a label image.
- Parameters:
labels – Label image. Typically Array[h, w].
colors – Either ‘rand’ or one of [‘Pastel1’, ‘Pastel2’, ‘Paired’, ‘Accent’, ‘Dark2’, ‘Set1’, ‘Set2’, ‘Set3’, ‘tab10’, ‘tab20’, ‘tab20b’, ‘tab20c’] (see matplotlib’s qualitative colormaps) or Array[n, c].
zero_val – Special color for the zero label (usually background).
rgba – Whether to add an alpha channel to rgb colors.
alpha – Specific alpha value.
- Returns:
Mapped labels. E.g. rgba mapping Array[h, w] -> Array[h, w, 4].
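The mapping can be illustrated with a small pure-Python sketch. label_cmap_sketch is a hypothetical helper, not the library's implementation. It makes these assumptions:
- labels are nested lists of non-negative ints;
- colors is a list of RGB tuples in [0, 1];
- zero_val is a scalar;
- alpha defaults to 1.0.

Non-zero labels cycle through the palette.

```python
from typing import List, Sequence, Tuple

Color = Tuple[float, ...]


def label_cmap_sketch(labels: List[List[int]], colors: Sequence[Color],
                      zero_val: float = 0.0, rgba: bool = True,
                      alpha: float = 1.0) -> List[List[Color]]:
    """Map each label of a 2D label image to a color (hypothetical sketch)."""
    background = (zero_val,) * (4 if rgba else 3)
    out = []
    for row in labels:
        mapped = []
        for label in row:
            if label == 0:
                # The zero label (usually background) gets the special color.
                mapped.append(background)
                continue
            # Non-zero labels cycle through the palette.
            rgb = tuple(colors[(label - 1) % len(colors)][:3])
            mapped.append(rgb + (alpha,) if rgba else rgb)
        out.append(mapped)
    return out


palette = [(1.0, 0.0, 0.0), (0.0, 0.0, 1.0)]
print(label_cmap_sketch([[0, 1], [2, 3]], palette))
```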
- random_colors_hsv(num, hue_range=(0, 180), saturation_range=(60, 133), value_range=(180, 256))
- get_axes(fig=None) -> List[SubplotBase]
Get current pyplot axes.
- Parameters:
fig – Optional Figure.
- Returns:
List of Axes.
- imshow(image: ndarray | Tensor, figsize=None, **kw)
Imshow.
PyPlot’s imshow function with benefits.
- Parameters:
image – Image. Valid Formats: Array[h, w], Array[h, w, c] or Array[n, h, w, c], Tensor[h, w], Tensor[c, h, w] or Tensor[n, c, h, w]. Images without channels or just one channel are plotted as grayscale images by default.
figsize – Figure size. If specified, a new plt.figure(figsize=figsize) is created.
**kw – Imshow keyword arguments.
- imshow_col(*images, titles=None, figsize=(3, 3), tight=True, **kwargs)
Imshow column.
Display a list of images in a column.
- Parameters:
*images – Images.
titles – Titles. Either string or list of strings (one for each image).
figsize – Figure size per image.
tight – Whether to use tight layout.
**kwargs – Keyword arguments for cd.imshow.
- imshow_grid(*images, titles=None, figsize=(3, 3), tight=True, **kwargs)
Imshow grid.
Display a list of images in an NxN grid.
- Parameters:
*images – Images.
titles – Titles. Either string or list of strings (one for each image).
figsize – Figure size per image.
tight – Whether to use tight layout.
**kwargs – Keyword arguments for cd.imshow.
- imshow_row(*images, titles=None, figsize=(3, 3), tight=True, **kwargs)
Imshow row.
Display a list of images in a row.
- Parameters:
*images – Images.
titles – Titles. Either string or list of strings (one for each image).
figsize – Figure size per image.
tight – Whether to use tight layout.
**kwargs – Keyword arguments for cd.imshow.
- plot_box(x_min, y_min, x_max, y_max, linewidth=1, edgecolor='#4AF626', facecolor='none', text=None, **kwargs)
- plot_boxes(boxes, texts: List[str] | None = None, **kwargs)
- plot_contours(contours, contour_line_width=2, contour_linestyle='-', fill=0.2, color=None, texts: list | None = None, **kwargs)
- plot_mask(mask, alpha=1)
- plot_score(score: float, x, y, cls: int | str | None = None, cls_names: dict | None = None, **kwargs)
- plot_text(text, x, y, color='black', stroke_width=5, stroke_color='w')
- quiver_plot(vector_field, image=None, cmap='gray', figsize=None, qcmap='twilight', linewidth=0.125, width=0.19, alpha=0.7)
Quiver plot.
Plots a 2D vector field. Can be used to visualize a local refinement tensor.
- Parameters:
vector_field – Array[2, w, h].
image – Array[h, w(, 3)].
cmap – Image color map.
figsize – Figure size.
qcmap – Quiver color map. Consider seaborn’s husl palette: qcmap = ListedColormap(sns.color_palette("husl", 8).as_hex())
linewidth – Quiver line width.
width – Quiver width.
alpha – Quiver alpha.
- save_fig(filename, close=True)
Save Figure.
Save current Figure to disk.
- Parameters:
filename – Filename, e.g. image.png.
close – Whether to close all unhandled Figures. Do not close them if you intend to call plt.show().
- show_detection(image=None, contours=None, coordinates=None, boxes=None, scores=None, masks=None, figsize=None, label_stack=None, classes: List[str] | List[int] | None = None, class_names: dict | None = None, contour_line_width=2, contour_linestyle='-', fill=0.2, cmap=Ellipsis, **kwargs)
Config
- class Config(**kwargs)
Config.
Just a dict with benefits. Config objects treat values as attributes, print nicely, and can be saved to and loaded from JSON files. The hash method also offers a unique and compact string representation of the Config content.
Examples
>>> import celldetection as cd, torch.nn as nn
>>> conf = cd.Config(optimizer={'Adam': dict(lr=.001)}, epochs=100)
>>> conf
Config(
  (optimizer): {'Adam': {'lr': 0.001}}
  (epochs): 100
)
>>> conf.to_json('config.json')
>>> conf.hash()
'cf647b987ca37eb954d8bd01df01809e'
>>> conf.epochs = 200
>>> conf.epochs
200
>>> module = nn.Conv2d(1, 2, 3)
>>> optimizer = cd.conf2optimizer(conf.optimizer, module.parameters())
>>> optimizer
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
- Parameters:
**kwargs – Items.
- args(fn: Callable)
Examples
>>> conf = cd.Config(a=1, b=2, c=42)
>>> def f(a, b):
...     return a + b
>>> f(*conf.args(f))
3
- Parameters:
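fn – Callable whose parameter names are looked up in the Config.
- Returns:
Values of the Config entries that match the parameter names of fn, to be passed as positional arguments.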
- extra_repr() -> str
- static from_json(filename)
- hash() -> str
- kwargs(fn: Callable)
Examples
>>> conf = cd.Config(a=1, b=2, c=42)
>>> def f(a, b):
...     return a + b
>>> f(**conf.kwargs(f))
3
- Parameters:
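fn – Callable whose parameter names are looked up in the Config.
- Returns:
Dictionary of the Config entries that match the parameter names of fn, to be passed as keyword arguments.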
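Both methods presumably select Config entries by the parameter names of fn. The sketch below assumes this selection is done with inspect.signature. args_for and kwargs_for are hypothetical helpers that operate on a plain dict, not the library's methods.

```python
import inspect
from typing import Any, Callable, Dict, List


def kwargs_for(config: Dict[str, Any], fn: Callable) -> Dict[str, Any]:
    """Return the config items whose keys match fn's parameter names (hypothetical)."""
    names = inspect.signature(fn).parameters
    return {name: config[name] for name in names if name in config}


def args_for(config: Dict[str, Any], fn: Callable) -> List[Any]:
    """Return matching config values in the order of fn's parameters (hypothetical)."""
    return list(kwargs_for(config, fn).values())


conf = dict(a=1, b=2, c=42)


def f(a, b):
    return a + b


print(f(*args_for(conf, f)), f(**kwargs_for(conf, f)))  # 3 3
```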
- load(filename)
- to_dict() -> dict
- to_json(filename)
- to_txt(filename, mode='w', **kwargs)
- class Schedule(**kwargs)
Schedule.
Provides an easy interface to the cross product of different configurations.
Examples
>>> s = cd.Schedule(
...     lr=(0.001, 0.0005),
...     net=('resnet34', 'resnet50'),
...     epochs=100
... )
>>> len(s)
4
>>> s[:]
[Config(
  (epochs): 100
  (lr): 0.001
  (net): 'resnet34'
), Config(
  (epochs): 100
  (lr): 0.001
  (net): 'resnet50'
), Config(
  (epochs): 100
  (lr): 0.0005
  (net): 'resnet34'
), Config(
  (epochs): 100
  (lr): 0.0005
  (net): 'resnet50'
)]
>>> for config in s:
...     print(config.lr, config.net, config.epochs)
0.001 resnet34 100
0.001 resnet50 100
0.0005 resnet34 100
0.0005 resnet50 100
- Parameters:
**kwargs – Configurations. Possible item layouts:
<name>: <static setting>,
<name>: (<option1>, ..., <optionN>),
<name>: [<option1>, ..., <optionN>],
<name>: {<option1>, ..., <optionN>}.
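The cross product over the item layouts above can be sketched with itertools.product. expand_schedule is a hypothetical helper, not the library's implementation. It returns plain dicts instead of Config objects and ignores the conditions mechanism of add. Keys are sorted only to make the output order deterministic.

```python
import itertools
from typing import Any, Dict, List


def expand_schedule(**kwargs: Any) -> List[Dict[str, Any]]:
    """Cross product of all options (hypothetical sketch of Schedule).

    Tuples, lists and sets are treated as options; anything else is a static setting.
    Note that the iteration order of a set is arbitrary.
    """
    keys = sorted(kwargs)
    options = []
    for key in keys:
        value = kwargs[key]
        if isinstance(value, (tuple, list, set)):
            options.append(list(value))
        else:
            options.append([value])  # a static setting is a single option
    return [dict(zip(keys, combo)) for combo in itertools.product(*options)]


configs = expand_schedule(lr=(0.001, 0.0005), net=('resnet34', 'resnet50'), epochs=100)
for c in configs:
    print(c['lr'], c['net'], c['epochs'])
```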
- add(d: dict | None = None, conditions: dict | None = None, **kwargs)
Add setting to schedule.
Examples
>>> schedule = cd.Schedule(model=('resnet18', 'resnet50'), batch_size=8)
>>> schedule.add(batch_size=(16, 32), conditions={'model': 'resnet18'})
>>> schedule[:]
[Config(
  (batch_size): 16,
  (model): resnet18,
), Config(
  (batch_size): 32,
  (model): resnet18,
), Config(
  (batch_size): 8,
  (model): resnet50,
)]
>>> schedule = cd.Schedule(model=('resnet18', 'resnet50'))
>>> schedule.add(batch_size=(16, 32), conditions={'model': 'resnet18'})
>>> schedule[:]
[Config(
  (model): resnet18,
  (batch_size): 16,
), Config(
  (model): resnet18,
  (batch_size): 32,
), Config(
  (model): resnet50,
)]
>>> schedule = cd.Schedule(model=('resnet18', 'resnet50'), batch_size=(64, 128, 256))
>>> schedule.add(batch_size=(16, 32), conditions={'model': 'resnet50'})
>>> schedule[:]
[Config(
  (batch_size): 64
  (model): 'resnet18'
), Config(
  (batch_size): 16
  (model): 'resnet50'
), Config(
  (batch_size): 32
  (model): 'resnet50'
), Config(
  (batch_size): 128
  (model): 'resnet18'
), Config(
  (batch_size): 256
  (model): 'resnet18'
)]
- Parameters:
d – Dictionary of settings.
conditions – If set, added settings are only applied if conditions are met. Note: Conditioned settings replace/override existing settings if conditions are met.
**kwargs – Configurations. Possible item layouts:
<name>: <static setting>,
<name>: (<option1>, ..., <optionN>),
<name>: [<option1>, ..., <optionN>],
<name>: {<option1>, ..., <optionN>}.
- property configs
- static from_json(filename)
- get_multiples(num=2)
- load(filename)
- property product
- to_dict()
- to_json(filename)
- conf2augmentation(settings: dict) -> Compose
Config to augmentation.
Maps settings to a composed augmentation workflow using albumentations.
Examples
>>> import celldetection as cd
>>> cd.conf2augmentation({
...     'RandomRotate90': dict(p=.5),
...     'Transpose': dict(p=.5),
... })
Compose([
  RandomRotate90(always_apply=False, p=0.5),
  Transpose(always_apply=False, p=0.5),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={})
- Parameters:
settings – Settings dictionary as {name: kwargs}.
- Returns:
A.Compose object.
- conf2call(settings: dict | str, origin, **kwargs)
Config to call.
Examples
>>> import celldetection as cd
>>> model = cd.conf2call('ResNet18', cd.models, in_channels=1)
>>> model = cd.conf2call({'ResNet18': dict(in_channels=1)}, cd.models)
- Parameters:
settings – Name or dictionary as {name: kwargs}. Name must be the name of the symbol that is to be retrieved from origin.
origin – Origin from which the symbol is retrieved, e.g. a module such as cd.models.
**kwargs – Additional keyword arguments for the call of retrieved symbol.
- Returns:
Return value of the call of retrieved symbol.
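Here is a minimal sketch of the lookup-and-call pattern. It assumes the symbol is retrieved with getattr. It also assumes that keyword arguments passed to the call take precedence over those in settings. conf2call_sketch and the models namespace are hypothetical stand-ins, not the library's code.

```python
from types import SimpleNamespace
from typing import Any, Dict, Union


def conf2call_sketch(settings: Union[str, Dict[str, Dict[str, Any]]], origin: Any,
                     **kwargs: Any) -> Any:
    """Look up a symbol by name on origin and call it (hypothetical sketch)."""
    if isinstance(settings, str):
        name, call_kwargs = settings, {}
    else:
        # Expect exactly one {name: kwargs} item; more items raise ValueError.
        (name, call_kwargs), = settings.items()
    symbol = getattr(origin, name)
    # Assumption: explicit keyword arguments override those from settings.
    return symbol(**{**call_kwargs, **kwargs})


# Hypothetical stand-in for a module such as cd.models.
models = SimpleNamespace(Linear=lambda in_channels, out_channels=2: (in_channels, out_channels))
print(conf2call_sketch('Linear', models, in_channels=1))  # (1, 2)
print(conf2call_sketch({'Linear': dict(in_channels=1, out_channels=4)}, models))  # (1, 4)
```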
- conf2optimizer(settings: dict, params)
Config to optimizer.
Examples
>>> import celldetection as cd, torch.nn as nn
>>> module = nn.Conv2d(1, 2, 3)
>>> optimizer = cd.conf2optimizer({'Adam': dict(lr=.0002, betas=(0.5, 0.999))}, module.parameters())
>>> optimizer
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    lr: 0.0002
    weight_decay: 0
)
- Parameters:
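settings – Settings dictionary as {name: kwargs}, e.g. {'Adam': dict(lr=.0002)}.
params – Parameters to be optimized, e.g. module.parameters().
- Returns:
Optimizer.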
- conf2scheduler(settings: dict, optimizer, origins=None)
- conf2tweaks_(settings: dict, module: Module)
Config to tweaks.
Apply tweaks to module.
Notes
If module does not contain specified objects, nothing happens.
Examples
>>> import celldetection as cd, torch.nn as nn
>>> model = cd.models.ResNet18(in_channels=3)
>>> cd.conf2tweaks_({nn.BatchNorm2d: dict(momentum=0.05)}, model)  # sets momentum to 0.05
>>> cd.conf2tweaks_({'BatchNorm2d': dict(momentum=0.42)}, model)  # sets momentum to 0.42
>>> cd.conf2tweaks_({'LeakyReLU': dict(negative_slope=0.2)}, model)  # sets negative_slope to 0.2
- Parameters:
settings – Settings dictionary as {name: kwargs}, where name is a module class or class name.
module – Module that is to be tweaked.
Timing
- print_timing(name, seconds)
- start_timer(name, cuda=True, collect=True)
Keyword PyTorch timer.
Can be used to measure PyTorch GPU times.
- Parameters:
name – Keyword that identifies the timer. Use the same keyword with stop_timer.
- stop_timer(name, cuda=True, verbose=True)
Util
- class Bytes
Bytes.
Printable integer that represents a number of bytes.
- UNITS = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB', 'BiB']
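A printable integer of this kind can be sketched as an int subclass. It divides by 1024 until the value fits a unit. BytesSketch is hypothetical, and its two-decimal format is an assumption, not the library's output format.

```python
class BytesSketch(int):
    """Integer that prints with a binary unit (hypothetical sketch of Bytes)."""

    UNITS = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']

    def __str__(self) -> str:
        value = float(self)
        for unit in self.UNITS[:-1]:
            if abs(value) < 1024:
                return f'{value:.2f} {unit}'
            value /= 1024  # move on to the next larger unit
        return f'{value:.2f} {self.UNITS[-1]}'

    __repr__ = __str__


print(BytesSketch(1536))  # 1.50 KiB
```

Because BytesSketch subclasses int, arithmetic still works on the raw byte count. Only printing is affected.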
- class Dict(**kwargs)
Dictionary.
Just a dict that treats values like attributes.
Examples
>>> import celldetection as cd
>>> d = cd.Dict(my_value=42)
>>> d.my_value
42
>>> d.my_value += 1
>>> d.my_value
43
- Parameters:
**kwargs – Items.
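Attribute access on a dict can be sketched by forwarding __getattr__, __setattr__ and __delattr__ to item access. AttrDict is a hypothetical stand-in for Dict, not its implementation. Raising AttributeError for missing keys keeps hasattr and getattr defaults working.

```python
from typing import Any


class AttrDict(dict):
    """A dict whose items are also accessible as attributes (hypothetical sketch)."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            # AttributeError (not KeyError) so that hasattr() behaves as expected.
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


d = AttrDict(my_value=42)
d.my_value += 1
print(d.my_value, d['my_value'])  # 43 43
```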
- class GpuStats(delimiter=', ')
GPU Statistics.
Simple interface to print live GPU statistics from pynvml.
Examples
>>> import celldetection as cd
>>> stat = cd.GpuStats()  # initialize once
>>> print(stat)  # print current statistics
gpu0(free: 22.55GB, used: 21.94GB, util: 93%), gpu1(free: 1.03GB, used: 43.46GB, util: 98%)
- Parameters:
delimiter – Delimiter used for printing.
- class NormProxy(norm, **kwargs)
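Norm Proxy.
Stores a normalization layer class (or its name) together with keyword arguments. Calling the proxy with the number of features creates the layer.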
Examples
>>> GroupNorm = NormProxy('groupnorm', num_groups=32)
>>> GroupNorm(3)
GroupNorm(32, 3, eps=1e-05, affine=True)
>>> GroupNorm = NormProxy(nn.GroupNorm, num_groups=32)
>>> GroupNorm(3)
GroupNorm(32, 3, eps=1e-05, affine=True)
>>> BatchNorm2d = NormProxy('batchnorm2d', momentum=.2)
>>> BatchNorm2d(3)
BatchNorm2d(3, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
>>> BatchNorm2d = NormProxy(nn.BatchNorm2d, momentum=.2)
>>> BatchNorm2d(3)
BatchNorm2d(3, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
- Parameters:
norm – Norm class or name.
**kwargs – Keyword arguments.
- class Percent(x=0, /)
Percent.
Printable float that represents a percentage.
- class Tiling(tile_size: tuple, context_shape: tuple, overlap=0)
- add_to_loss_dict(d: dict, key: str, loss: Tensor, weight=None)
- append_hash_to_filename(filename, num=None, ext=True)
- asnumpy(v)
As numpy.
Converts all Tensors to numpy arrays.
Notes
Works recursively.
The following input items are not altered: NumPy arrays, int, float, bool, str.
- Parameters:
v – Tensor or list/tuple/dict of Tensors.
- Returns:
Input with Tensors converted to numpy arrays.
- copy_script(dst, no_script_okay=True, frame=None, verbose=False)
Copy current script.
Copies the script from where this function is called to dst. By default, nothing happens if this function is not called from within a script.
- Parameters:
dst – Copy destination. Filename or folder.
no_script_okay – If False, raise FileNotFoundError if no script is found.
verbose – Whether to print source and destination when copying.
- count_submodules(module: Module, class_or_tuple) -> int
Count submodules.
Count the number of submodules of the specified type(s).
Examples
>>> count_submodules(cd.models.U22(1, 0), nn.Conv2d)
22
- Parameters:
module – Module.
class_or_tuple – All instances of given class_or_tuple are to be counted.
- Returns:
Number of submodules.
- dict_hash(dictionary: Dict[str, Any]) -> str
MD5 hash of a dictionary.
References
https://www.doc.ic.ac.uk/~nuric/coding/how-to-hash-a-dictionary-in-python.html
- Parameters:
dictionary – A dictionary.
- Returns:
MD5 hash of the dictionary as a string.
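Following the approach of the referenced article, the sketch serializes the dictionary to JSON with sorted keys. It then hashes the bytes with MD5. dict_hash_sketch is hypothetical and supports only JSON-serializable values. The library's implementation may differ.

```python
import hashlib
import json
from typing import Any, Dict


def dict_hash_sketch(dictionary: Dict[str, Any]) -> str:
    """MD5 hash of a JSON-serializable dictionary (hypothetical sketch)."""
    # sort_keys=True makes the hash independent of insertion order.
    encoded = json.dumps(dictionary, sort_keys=True).encode()
    return hashlib.md5(encoded).hexdigest()


print(dict_hash_sketch({'epochs': 100, 'lr': 0.001}))
```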
- ensure_num_tuple(v, num=2, msg='')
- exponential_moving_average_(module_avg, module, alpha=0.999, alpha_non_trainable=0.0, buffers=True)
Exponential moving average.
Update the variables of module_avg to be slightly closer to module.
Notes
Whether a parameter is trainable or not is checked on module.
module_avg can be on a different device and entirely frozen.
- Parameters:
module_avg – Average module. The parameters of this model are to be updated.
module – Other Module.
alpha – Fraction of trainable parameters of module_avg; (1 - alpha) is the fraction of trainable parameters of module.
alpha_non_trainable – Same as alpha, but for non-trainable parameters.
buffers – Whether to copy buffers from module to module_avg.
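The alpha description implies the update rule avg = alpha * avg + (1 - alpha) * current. The sketch below applies this rule to named floats instead of tensors. ema_update_ and the trainable set are hypothetical, and buffer copying is omitted.

```python
from typing import Dict, Set


def ema_update_(avg: Dict[str, float], current: Dict[str, float], trainable: Set[str],
                alpha: float = 0.999, alpha_non_trainable: float = 0.0) -> None:
    """In-place EMA update of named scalar parameters (hypothetical sketch)."""
    for name, value in current.items():
        # Trainability is decided by the current model, not by the average model.
        a = alpha if name in trainable else alpha_non_trainable
        avg[name] = a * avg[name] + (1 - a) * value


avg = {'w': 0.0, 'running_mean': 0.0}
ema_update_(avg, {'w': 1.0, 'running_mean': 5.0}, trainable={'w'}, alpha=0.9)
print(avg)  # w is approximately 0.1; running_mean is copied (5.0)
```

With the default alpha_non_trainable=0.0, non-trainable values are simply copied from the current model.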
- fetch_image(url, numpy=True)
Fetch image from URL.
Download an image from URL and convert it to a numpy array or PIL Image.
- Parameters:
url – URL
numpy – Whether to convert PIL Image to numpy array.
- Returns:
PIL Image or numpy array.
- fetch_model(name, map_location=None, **kwargs)
Fetch model from URL.
Loads model or state dict from URL.
- Parameters:
name – Model name hosted on celldetection.org or URL. URLs must start with ‘http’.
map_location – A function, torch.device, string or a dict specifying how to remap storage locations.
**kwargs – Keyword arguments for load_state_dict_from_url (see torch.hub.load_state_dict_from_url).
- from_h5(filename, *keys, **keys_slices)
From h5.
Reads data from hdf5 file.
- Parameters:
filename – Filename.
*keys – Keys to read.
**keys_slices – Keys with indices or slices. E.g. from_h5('file.h5', 'key0', key=slice(0, 42)).
- Returns:
Data from hdf5 file. As tuple if multiple keys are provided.
- from_json(filename)
From JSON.
Load object from JSON file with name filename.
- Parameters:
filename – File name.
- frozen_params(module: Module, recurse=True) -> Iterator[Parameter]
Frozen parameters.
Retrieve all frozen parameters.
- Parameters:
module – Module.
recurse – Whether to also include parameters of all submodules.
- Returns:
Module parameters.
- gaussian_kernel(kernel_size, sigma=-1, nd=2) -> ndarray
Get Gaussian kernel.
Constructs and returns a Gaussian kernel.
- Parameters:
kernel_size – Kernel size as int or tuple. It should be odd and positive.
sigma – Gaussian standard deviation as float or tuple. If it is non-positive, it is computed from kernel_size as sigma = 0.3*((kernel_size-1)*0.5 - 1) + 0.8.
nd – Number of kernel dimensions.
- Returns:
Gaussian Kernel.
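Here is a sketch of a normalized, separable Gaussian kernel. It uses the fallback sigma formula above, which is the same formula OpenCV's getGaussianKernel uses. gaussian_kernel_1d and gaussian_kernel_2d are hypothetical, pure-Python helpers. They cover only square 1D/2D kernels with a scalar sigma.

```python
import math
from typing import List


def gaussian_kernel_1d(kernel_size: int, sigma: float = -1) -> List[float]:
    """Normalized 1D Gaussian kernel (hypothetical sketch)."""
    if sigma <= 0:
        # Fallback: derive sigma from the kernel size.
        sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
    center = (kernel_size - 1) / 2
    weights = [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(kernel_size)]
    total = sum(weights)
    return [w / total for w in weights]  # weights sum to 1


def gaussian_kernel_2d(kernel_size: int, sigma: float = -1) -> List[List[float]]:
    """2D kernel as the outer product of two 1D kernels (the Gaussian is separable)."""
    k = gaussian_kernel_1d(kernel_size, sigma)
    return [[a * b for b in k] for a in k]


kernel = gaussian_kernel_2d(5)  # sigma = 0.3 * (2 - 1) + 0.8 = 1.1
```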
- get_device(module: Module | Tensor | device)
Get device.
Get device from Module.
- Parameters:
module – Module. If module is a string or torch.device already, it is returned as is.
- Returns:
Device.
- get_nd_batchnorm(dim: int)
- get_nd_conv(dim: int)
- get_nd_dropout(dim: int)
- get_nd_linear(dim: int)
- get_nd_max_pool(dim: int)
- get_nn(item: str | Module | Type[Module], src=None, nd=None)
- get_tiling_slices(size: Sequence[int], crop_size: int | Sequence[int], strides: int | Sequence[int]) -> Iterable[slice] | Tuple[int]
Get tiling slices.
- Parameters:
size – Reference size as tuple.
crop_size – Crop size.
strides – Strides.
- Returns:
Iterator of tiling slices (each slice defining a tile), Number of tiles per dimension as tuple.
- Return type:
Iterable[slice], Tuple[int]
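Here is one way to produce such slices. The sketch assumes tiles start at multiples of the stride. It also assumes the last tile may extend past the end of a dimension, where slicing then clips it. tiling_slices_sketch is hypothetical, and the library's handling of borders may differ.

```python
import itertools
import math
from typing import Iterator, List, Sequence, Tuple, Union


def _as_tuple(v: Union[int, Sequence[int]], n: int) -> Tuple[int, ...]:
    """Broadcast an int to an n-tuple; pass sequences through."""
    return (v,) * n if isinstance(v, int) else tuple(v)


def tiling_slices_sketch(size: Sequence[int], crop_size: Union[int, Sequence[int]],
                         strides: Union[int, Sequence[int]]
                         ) -> Tuple[Iterator[Tuple[slice, ...]], Tuple[int, ...]]:
    """Return an iterator of per-tile slice tuples and the tile count per dimension."""
    crops = _as_tuple(crop_size, len(size))
    steps = _as_tuple(strides, len(size))
    per_dim: List[List[slice]] = []
    for s, c, st in zip(size, crops, steps):
        # Enough tiles for the last one to reach the end of this dimension.
        n = max(math.ceil((s - c) / st), 0) + 1
        per_dim.append([slice(i * st, i * st + c) for i in range(n)])
    counts = tuple(len(slices) for slices in per_dim)
    return itertools.product(*per_dim), counts


tiles, num_tiles = tiling_slices_sketch((10, 6), crop_size=4, strides=4)
print(num_tiles)  # (3, 2)
print(next(tiles))  # (slice(0, 4, None), slice(0, 4, None))
```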
- get_warmup_factor(step, steps=1000, factor=0.001, method='linear')
- hash_file(filename)
- inject_extra_repr_(module, name, fn)
Inject extra representation.
Injects an additional extra_repr function into module. This can be helpful to indicate the presence of hooks.
Note
This is an in-place operation.
Notes
This op may impair pickling.
- Parameters:
module – Module.
name – Name of the injected function (only used to avoid duplicate injection).
fn – Callback function.
- iter_submodules(module: Module, class_or_tuple, recursive=True)
- load_image(name, method='imageio') -> ndarray
Load image.
Load image from URL or from filename via imageio or pytiff.
- Parameters:
name – URL (must start with http) or filename.
method – Method to use for filenames.
- Returns:
Image.
- load_model(filename, map_location=None, **kwargs)
- lookup_nn(item: str, *a, src=None, call=True, inplace=True, nd=None, **kw)
Examples
>>> lookup_nn('batchnorm2d', 32)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> lookup_nn(torch.nn.BatchNorm2d, 32)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> lookup_nn('batchnorm2d', num_features=32)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> lookup_nn('tanh')
Tanh()
>>> lookup_nn('tanh', call=False)
torch.nn.modules.activation.Tanh
>>> lookup_nn('relu')
ReLU(inplace=True)
>>> lookup_nn('relu', inplace=False)
ReLU()
>>> # Dict notation to contain all keyword arguments for calling in `item`. Always called once.
... lookup_nn(dict(relu=dict(inplace=True)), call=False)
ReLU(inplace=True)
>>> lookup_nn({'NormProxy': {'norm': 'GroupNorm', 'num_groups': 32}}, call=False)
NormProxy(GroupNorm, kwargs={'num_groups': 32})
>>> lookup_nn({'NormProxy': {'norm': 'GroupNorm', 'num_groups': 32}}, 32, call=True)
GroupNorm(32, 32, eps=1e-05, affine=True)
- Parameters:
item – Lookup item. None is equivalent to identity.
*a – Arguments passed to item if called.
src – Lookup source.
call – Whether to call item.
inplace – Default setting for items that take an inplace argument when called. As the default is True, lookup_nn('relu') returns a ReLU instance with inplace=True.
nd – If set, replace the dimension statement (e.g. '2d' in nn.Conv2d) with nd.
**kw – Keyword arguments passed to item when it is called.
- Returns:
Looked up item.
- num_bytes(x: ndarray | Tensor)
Num Bytes.
Returns the size in bytes of the given ndarray or Tensor.
- Parameters:
x – Array or Tensor.
- Returns:
Bytes
- num_params(module: Module, trainable=None, recurse=True) -> int
Number of parameters.
Count the number of parameters.
- Parameters:
module – Module
trainable – Optionally filter for trainable or frozen parameters.
recurse – Whether to also include parameters of all submodules.
- Returns:
Number of parameters.
- print_to_file(*args, filename, mode='w', **kwargs)
- random_code_name(chars=4) -> str
Random code name.
Generates random code names that are somewhat pronounceable and memorable.
Examples
>>> import celldetection as cd
>>> cd.random_code_name()
'kolo'
>>> cd.random_code_name(6)
'lotexo'
- Parameters:
chars – Number of characters.
- Returns:
String.
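Alternating consonants and vowels is one simple way to make names pronounceable. random_code_name_sketch and its letter pools are hypothetical. The library's generator may use different rules.

```python
import random
from typing import Optional

CONSONANTS = 'bdfgklmnprstvz'  # hypothetical letter pools
VOWELS = 'aeiou'


def random_code_name_sketch(chars: int = 4, rng: Optional[random.Random] = None) -> str:
    """Alternate consonants and vowels to get a pronounceable name (hypothetical sketch)."""
    rng = rng or random.Random()
    pools = (CONSONANTS, VOWELS)
    return ''.join(rng.choice(pools[i % 2]) for i in range(chars))


print(random_code_name_sketch(6, random.Random(0)))  # a six-letter name such as consonant-vowel-...
```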
- random_code_name_dir(directory='./out', chars=6, comm=None, root_rank=0)
Random code name directory.
Creates a random code name and a subdirectory with that name under directory. Code names that are already taken (subdirectory already exists) are not reused.
- Parameters:
directory – Root directory.
chars – Number of characters for the code name.
comm – MPI Comm. If provided, the code name and directory are automatically broadcast to all ranks of comm.
root_rank – Root rank. Only the root rank creates code name and directory.
- Returns:
Tuple of code name and created directory.
- random_seed(seed, backends=False, deterministic_torch=True)
Set random seed.
Set the random seed for random, np.random, torch.backends.cudnn and torch.manual_seed. Also advises torch to use deterministic algorithms.
References
https://pytorch.org/docs/stable/notes/randomness.html
- Parameters:
seed – Random seed.
backends – Whether to also adapt backends. If True, cuDNN’s benchmark feature is disabled. This causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced performance. The selected algorithm is also set to run deterministically.
deterministic_torch – Whether to set PyTorch operations to behave deterministically.
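Here is a sketch that follows the PyTorch randomness notes referenced above. random_seed_sketch is hypothetical. NumPy and PyTorch are only seeded if they are installed. The backends and deterministic_torch flags map onto the standard torch.backends.cudnn and torch.use_deterministic_algorithms settings.

```python
import random


def random_seed_sketch(seed: int, backends: bool = False, deterministic_torch: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs if available (hypothetical sketch)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)  # seeds the RNGs of all devices
    if backends:
        # Let cuDNN pick algorithms deterministically instead of benchmarking.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    if deterministic_torch:
        torch.use_deterministic_algorithms(True)


random_seed_sketch(42)
first = [random.random() for _ in range(3)]
random_seed_sketch(42)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```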
- reduce_loss_dict(losses: dict, divisor)
- replace_module_(module: Module, class_or_tuple, substitute: Type[Module] | Module, recursive=True, inherit_attr: list | str | dict | None = None, **kwargs)
Replace module.
Replace all occurrences of class_or_tuple in module with substitute.
Examples
>>> # Replace all ReLU activations with LeakyReLU
... cd.replace_module_(network, nn.ReLU, nn.LeakyReLU)
>>> # Replace all BatchNorm layers with InstanceNorm and inherit `num_features` attribute
... cd.replace_module_(network, nn.BatchNorm2d, nn.InstanceNorm2d, inherit_attr=['num_features'])
>>> # Replace all BatchNorm layers with GroupNorm and inherit `num_features` attribute
... cd.replace_module_(network, nn.BatchNorm2d, nn.GroupNorm, num_groups=32,
...                    inherit_attr={'num_channels': 'num_features'})
- Parameters:
module – Module.
class_or_tuple – Class or tuple of classes that are to be replaced.
substitute – Substitute class or object.
recursive – Whether to replace modules recursively.
inherit_attr – Attributes to be inherited. String, list or dict of attribute names. Attribute values are retrieved from the replaced module and passed to the substitute constructor. Formats: 'attr_name', ['attr_name0', 'attr_name1', ...], {'substitute_kw0': 'attr_name0', ...}
**kwargs – Keyword arguments passed to substitute constructor if it is a class.
- save_fetchable_model(model: Module, filename, append_hash=16)
- spectral_norm_(module, class_or_tuple=<class 'torch.nn.modules.conv.Conv2d'>, recursive=True, name='weight', add_repr=False, **kwargs)
Spectral normalization.
Applies spectral normalization to parameters of all occurrences of class_or_tuple in the given module.
Note
This is an in-place operation.
- Parameters:
module – Module.
class_or_tuple – Class or tuple of classes whose parameters are to be normalized.
recursive – Whether to search for modules recursively.
name – Name of weight parameter.
add_repr – Whether to indicate use of spectral norm in a module’s representation. Note that this may impair pickling.
**kwargs – Additional keyword arguments for torch.nn.utils.spectral_norm.
- tensor_to(inputs: list | tuple | dict | Tensor, *args, **kwargs)
Tensor to device/dtype/other.
Recursively calls tensor.to(*args, **kwargs) for all Tensors in inputs.
Notes
Works recursively.
Non-Tensor items are not altered.
- Parameters:
inputs – Tensor, list, tuple or dict. Non-Tensor objects are ignored. Tensors are substituted by the result of the tensor.to(*args, **kwargs) call.
*args – Arguments. See docstring of torch.Tensor.to.
**kwargs – Keyword arguments. See docstring of torch.Tensor.to.
- Returns:
Inputs with Tensors replaced by the result of tensor.to(*args, **kwargs).
- to_device(batch: list | tuple | dict | Tensor, device)
To device.
Move Tensors to device. Input can be Tensor, tuple of Tensors, list of Tensors or a dictionary of Tensors.
Notes
Works recursively.
Non-Tensor items are not altered.
- Parameters:
batch – Tensor, list, tuple or dict. Non-Tensor objects are ignored. Tensors are moved to device.
device – Device.
- Returns:
Input with Tensors moved to device.
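Both to_device and tensor_to follow a recursive pattern over lists, tuples and dicts. map_nested is a hypothetical, framework-agnostic sketch of that pattern. With PyTorch, is_leaf would presumably be lambda x: isinstance(x, torch.Tensor), and fn would be lambda t: t.to(device). Subclasses such as namedtuples are not handled.

```python
from typing import Any, Callable


def map_nested(obj: Any, is_leaf: Callable[[Any], bool], fn: Callable[[Any], Any]) -> Any:
    """Apply fn to every leaf in nested lists/tuples/dicts (hypothetical sketch)."""
    if is_leaf(obj):
        return fn(obj)
    if isinstance(obj, dict):
        return {k: map_nested(v, is_leaf, fn) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild the container with the same type (list or tuple).
        return type(obj)(map_nested(v, is_leaf, fn) for v in obj)
    return obj  # other items are not altered


batch = {'inputs': [1.0, 2.0], 'meta': ('id', 3.0), 'name': 'x'}
print(map_nested(batch, lambda x: isinstance(x, float), lambda x: x * 10))
# {'inputs': [10.0, 20.0], 'meta': ('id', 30.0), 'name': 'x'}
```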
- to_h5(filename, mode='w', chunks=None, compression=None, overwrite=False, create_dataset_kw: dict | None = None, **kwargs)
To hdf5 file.
Write data to hdf5 file.
- Parameters:
filename – File name.
mode – Mode.
chunks – Chunks setting for created datasets. Chunk shape, or True to enable auto-chunking.
compression – Compression setting for created datasets. Legal values are ‘gzip’, ‘szip’, ‘lzf’. If an integer in range(10), this indicates gzip compression level. Otherwise, an integer indicates the number of a dynamically loaded compression filter.
overwrite – Whether to overwrite existing dataset.
create_dataset_kw – Additional keyword arguments for h5py.File().create_dataset.
**kwargs – Data as {dataset_name: data}.
- to_json(filename, obj, mode='w')
To JSON.
Dump obj to JSON file with name filename.
- Parameters:
filename – File name.
obj – Object.
mode – File mode.
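Here is a JSON round trip with the standard library, as a minimal sketch of to_json and from_json. to_json_sketch and from_json_sketch are hypothetical. JSON has no tuple type, so tuples come back as lists.

```python
import json
import os
import tempfile
from typing import Any


def to_json_sketch(filename: str, obj: Any, mode: str = 'w') -> None:
    """Dump obj to a JSON file (hypothetical sketch)."""
    with open(filename, mode) as f:
        json.dump(obj, f)


def from_json_sketch(filename: str) -> Any:
    """Load an object from a JSON file (hypothetical sketch)."""
    with open(filename) as f:
        return json.load(f)


path = os.path.join(tempfile.mkdtemp(), 'config.json')
to_json_sketch(path, {'epochs': 100, 'lr': 0.001})
print(from_json_sketch(path))  # {'epochs': 100, 'lr': 0.001}
```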
- to_tiff(filename, image, mode='w', method='tile', bigtiff=True)
To tiff file.
Write image to a TIFF file using pytiff. By default, the TIFF is tiled, so that crops can be read from disk without loading the entire image into memory first.
Notes
pytiff must be installed to use this function.
References
https://pytiff.readthedocs.io/en/master/quickstart.html
- Parameters:
filename – File name.
image – Image.
mode – Mode.
method – Method. Either 'tile' or 'scanline'.
bigtiff – Whether to use the BigTIFF format.
- train_epoch(model, train_loader, device, optimizer, desc=None, scaler=None, scheduler=None, gpu_stats=False, progress=True)
Basic train function.
Notes
Model should return a dictionary: {'loss': Tensor[], ...}
Batch from train_loader should be a dictionary: {'inputs': Tensor[...], ...}
Model must be callable: model(batch['inputs'], targets=batch)
- Parameters:
model – Model.
train_loader – Data loader.
device – Device.
optimizer – Optimizer.
desc – Description, appears in progress print.
scaler – Gradient scaler. If set, PyTorch’s autocast feature is used.
scheduler – Scheduler. Its step is called after the epoch.
gpu_stats – Whether to print GPU stats.
progress – Show progress.
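The contract in the Notes can be sketched as a plain training loop. train_epoch_sketch is hypothetical. It omits device transfer, the gradient scaler and progress output. The stand-in loss, optimizer and model exist only so that the example runs without PyTorch.

```python
from typing import Any, Dict, Iterable, Optional


def train_epoch_sketch(model, train_loader: Iterable[Dict[str, Any]], optimizer,
                       scheduler: Optional[Any] = None) -> float:
    """Run one epoch and return the summed loss (hypothetical sketch)."""
    total = 0.0
    for batch in train_loader:
        outputs = model(batch['inputs'], targets=batch)  # model returns {'loss': ..., ...}
        loss = outputs['loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += float(loss)
    if scheduler is not None:
        scheduler.step()  # stepped once, after the epoch
    return total


class FakeLoss:
    """Stand-in for a scalar loss Tensor."""

    def __init__(self, value: float):
        self.value = value

    def backward(self) -> None:
        pass

    def __float__(self) -> float:
        return self.value


class StepCounter:
    """Stand-in optimizer/scheduler that counts step() calls."""

    def __init__(self):
        self.steps = 0

    def zero_grad(self) -> None:
        pass

    def step(self) -> None:
        self.steps += 1


def fake_model(inputs, targets=None):
    return {'loss': FakeLoss(sum(inputs))}


optimizer, scheduler = StepCounter(), StepCounter()
loader = [{'inputs': [1.0, 2.0]}, {'inputs': [3.0]}]
print(train_epoch_sketch(fake_model, loader, optimizer, scheduler))  # 6.0
```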
- trainable_params(module: Module, recurse=True) -> Iterator[Parameter]
Trainable parameters.
Retrieve all trainable parameters.
- Parameters:
module – Module.
recurse – Whether to also include parameters of all submodules.
- Returns:
Module parameters.
- tweak_module_(module: Module, class_or_tuple, must_exist=True, recursive=True, **kwargs)
Tweak module.
Set attributes for all modules that are instances of given class_or_tuple.
Examples
>>> import celldetection as cd, torch.nn as nn
>>> model = cd.models.ResNet18(in_channels=3)
>>> cd.tweak_module_(model, nn.BatchNorm2d, momentum=0.05)  # sets momentum to 0.05
Notes
This is an in-place operation.
- Parameters:
module – PyTorch Module.
class_or_tuple – All instances of given class_or_tuple are to be tweaked.
must_exist – If True, an AttributeError is raised if the keywords do not exist.
recursive – Whether to search for modules recursively.
**kwargs – Attributes to be tweaked: <attribute_name>=<value>.
- update_dict_(dst, src, override=False, keys: List[str] | Tuple[str] | None = None)
- weight_norm_(module, class_or_tuple=<class 'torch.nn.modules.conv.Conv2d'>, recursive=True, name='weight', add_repr=False, **kwargs)
Weight normalization.
Applies weight normalization to parameters of all occurrences of class_or_tuple in the given module.
Note
This is an in-place operation.
- Parameters:
module – Module.
class_or_tuple – Class or tuple of classes whose parameters are to be normalized.
recursive – Whether to search for modules recursively.
name – Name of weight parameter.
add_repr – Whether to indicate use of weight norm in a module’s representation. Note that this may impair pickling.
**kwargs – Additional keyword arguments for torch.nn.utils.weight_norm.
- wrap_module_(module: Module, class_or_tuple, wrapper, recursive=True, **kwargs)