StyleGAN2(-ada) is known for generating photorealistic images of, for example, portraits [Figure 1], and is now widely used in research, education and entertainment. These models can be deployed by distributing the model files directly, through dedicated model-hosting services, or in the browser. In-browser deployment has so far been the most challenging option, yet it offers a good trade-off between usability and cost. Since my last attempt a couple of months ago, new technologies have finally appeared that make this properly achievable. In this post I explain the steps required to run StyleGAN2-ada models in your browser using onnxruntime, based on the knowledge and technologies available in October 2021. Everything was tested on Windows 10.

This approach may also work for StyleGAN3 in the future, as NVLabs states on the StyleGAN3 GitHub page: “This repository is an updated version of stylegan2-ada-pytorch”. Unfortunately, some of the ops used in StyleGAN3 (affine_grid_generator) are not yet supported in ONNX.

## Preparation

It is important that your models are in StyleGAN2-ada PyTorch format. Either you trained your model in StyleGAN2(-ada) PyTorch, or you convert the TensorFlow versions using the legacy converter (legacy.py in the stylegan2-ada-pytorch repository). This should allow you to use any official StyleGAN2 model. Example code can be found below.

```bash
# Bash syntax; in Windows cmd, put the command on a single line
python legacy.py \
  --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \
  --dest=stylegan2-cat-config-f.pkl
```

## PyTorch model to ONNX

Next, we need to convert your model to the ONNX format. The snippet below loads the model, defines z (the latent vector) and c (the class label) as input variables and Y as the output variable, and exports the graph using ONNX opset 10. The code will probably print some warnings, but in the end you will have a file with the .onnx extension. This configuration was the only one that worked properly for me. Alternatively, you can find a converter script on GitHub.

```python
# Run this from the root of the stylegan2-ada-pytorch repository,
# so that dnnlib and legacy can be imported.
import functools

import dnnlib
import legacy
import numpy as np
import torch

device = torch.device('cpu')
with dnnlib.util.open_url(r'path\to\your\network.pkl') as f:
    GG = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

# Enforce 32-bit floats (might be redundant in later versions)
GG.forward = functools.partial(GG.forward, force_fp32=True)

# Set input and output names
in_names = ["z", "c"]
out_names = ["Y"]

# Use a dummy input and label to trace the graph.
# Note: randn returns float64, so the exported "z" input is float64 (double),
# which matches the 'float64' tensor fed in the browser example below.
dummy_input = torch.from_numpy(np.random.RandomState(0).randn(1, GG.z_dim)).to(device)
label = torch.zeros([1, GG.c_dim], device=device)

# Export as ONNX file
torch.onnx.export(model=GG,
                  args=(dummy_input, label),
                  f="your_onnx_model.onnx",
                  input_names=in_names,
                  output_names=out_names,
                  verbose=True,
                  opset_version=10,
                  export_params=True,
                  do_constant_folding=False,
                  # may need to be removed on newer PyTorch versions
                  use_external_data_format=False,
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX)
```

## Minimalist html + js

A minimalist HTML file with embedded JavaScript can be found below. Note that my model does not use StyleGAN2 class labels, so I don't know whether this works if your model uses them. Name the file index.html and place it in the same folder as your .onnx file. Opening it directly from disk will most likely not work, because the browser refuses to fetch the model file from a file:// URL, so I suggest trying it on a web server. Also note that this example uses a fixed 512×512 resolution (and a latent vector of 512 values) for simplicity. If your model uses a different resolution, change the image resolution values in the example below! Alternatively, you can find example HTML files on GitHub.

```html
<!DOCTYPE html>
<html>
<title>ONNX Runtime JavaScript example</title>
<body>
  <canvas id="gan-canvas" width="512" height="512" style="border:1px solid #d3d3d3;"></canvas>
  <!-- import ONNXRuntime Web from CDN -->
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
  <script>
    const ganCanvas = document.getElementById('gan-canvas');

    // use an async context to call onnxruntime functions.
    async function main() {
      try {
        // Create session
        const session = await ort.InferenceSession.create('./your_onnx_model.onnx');

        // Create the latent vector z (float64, matching the exported "z" input)
        const dataA = new Float64Array(512);

        // Random vector with values uniformly distributed between -2 and 2
        for (let i = 0; i < 512; i++) {
          dataA[i] = (Math.random() - 0.5) * 4;
        }

        // Assign (if you use classes, you might need to expand this)
        const tensorA = new ort.Tensor('float64', dataA, [1, 512]);
        const feeds = { z: tensorA };

        // Run
        const results = await session.run(feeds);

        // Get the result: planar output of shape [1, 3, 512, 512]
        const dataC = results['Y'];

        // Get the canvas
        const ctx = ganCanvas.getContext('2d');

        // first, create a new ImageData to contain our pixels
        const imgData = ctx.createImageData(512, 512); // width x height

        // get data pointer (interleaved RGBA, 4 values per pixel)
        const data = imgData.data;

        // assign color info: map roughly [-1, 1] to [0, 255]; the red, green and
        // blue planes start at offsets 0, 512*512 and 2*512*512 in the output.
        // imgData.data is a Uint8ClampedArray, so values are rounded and clamped.
        let offsetD = 0;
        const offsetG = 512 * 512;
        const offsetB = 512 * 512 * 2;
        for (let i = 0; i < 512 * 512; i++) {
          data[offsetD]     = dataC.data[i] * 127.5 + 128;
          data[offsetD + 1] = dataC.data[i + offsetG] * 127.5 + 128;
          data[offsetD + 2] = dataC.data[i + offsetB] * 127.5 + 128;
          data[offsetD + 3] = 255;
          offsetD += 4;
        }

        // fill canvas
        ctx.putImageData(imgData, 0, 0);
      } catch (e) {
        document.write(`failed to inference ONNX model: ${e}.`);
      }
    }
    main();
  </script>
</body>
</html>
```
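The colour loop in the HTML converts the generator's planar output (all red values, then all green, then all blue, i.e. NCHW layout) into the interleaved RGBA layout that ImageData expects. Channel `c` of pixel `i` sits at flat index `c*H*W + i`, which is where the offsets of 512*512 and 2*512*512 come from. To check that mapping outside the browser, here is a minimal NumPy sketch of the same conversion. The function name `nchw_to_rgba` is hypothetical, and it assumes an output of shape (1, 3, H, W) with values roughly in [-1, 1].

```python
import numpy as np


def nchw_to_rgba(y: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert a (1, 3, H, W) generator output to an
    (H, W, 4) uint8 RGBA image, mirroring the JavaScript colour loop."""
    if y.ndim != 4 or y.shape[0] != 1 or y.shape[1] != 3:
        raise ValueError(f"expected shape (1, 3, H, W), got {y.shape}")
    # Same scaling as the JS loop: value * 127.5 + 128
    rgb = y[0] * 127.5 + 128.0                   # (3, H, W), planar
    # Round half to even and clamp to 0..255, like a Uint8ClampedArray
    rgb = np.clip(np.rint(rgb), 0, 255).astype(np.uint8)
    rgb = np.transpose(rgb, (1, 2, 0))           # (H, W, 3), interleaved
    alpha = np.full(rgb.shape[:2] + (1,), 255, dtype=np.uint8)
    return np.concatenate([rgb, alpha], axis=-1)  # (H, W, 4)
```

The transpose does in one step what the JS loop does with its three offsets: it moves the channel axis last so each pixel's R, G and B values become adjacent.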

Opening index.html from a web server should display a generated image. Because z is sampled anew on every page load, refreshing the page produces a new image. An example, as used in the Abstract Art Generator V2, is shown in [Figure 2].
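For local testing, any static web server is enough, as long as the page and the .onnx file are fetched over HTTP rather than file://. A minimal sketch using only the Python standard library (Python 3.7+ assumed; `make_server` and `serve` are hypothetical helper names, and 127.0.0.1:8000 is an arbitrary choice):

```python
import functools
import http.server


def make_server(directory: str = ".", port: int = 8000) -> http.server.ThreadingHTTPServer:
    """Hypothetical helper: create (but do not start) a static file server
    for `directory`, bound to localhost only."""
    handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)
    return http.server.ThreadingHTTPServer(("127.0.0.1", port), handler)


def serve(directory: str = ".", port: int = 8000) -> None:
    """Serve index.html and the .onnx file until interrupted with Ctrl+C."""
    with make_server(directory, port) as server:
        print(f"Open http://127.0.0.1:{port}/index.html")
        server.serve_forever()
```

Call `serve()` from the folder that holds index.html and your_onnx_model.onnx, then open the printed URL. From the command line, `python -m http.server 8000 --bind 127.0.0.1` does roughly the same.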

## Pros & Cons

There are a few points worth discussing about this approach.