Dataset for Stable Diffusion

1. Dataset for Stable Diffusion

1.1 Dataset

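This section uses the Flickr8k dataset, which comes as two parts. The first is Flicker8k_Dataset (a folder containing all the images); the second is Flickr8k.token.txt (two columns, image_id and caption). Each image_id corresponds to 5 captions (sentences).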
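
As a small illustration of the two-column layout, the sketch below groups captions by image id. It assumes each line of Flickr8k.token.txt has the form `<image file name>#<caption number>`, a tab, then the caption. The helper name `parse_token_lines` and the sample lines are made up for this example.

```python
from collections import defaultdict

def parse_token_lines(lines):
    """Group captions by image id.

    Assumes each line looks like '<image_id>#<n>\t<caption>'.
    """
    captions = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Split only on the first tab: left is 'image_id#n', right is the caption
        key, caption = line.split("\t", 1)
        # Drop the '#n' suffix to recover the image id
        image_id = key.rsplit("#", 1)[0]
        captions[image_id].append(caption)
    return dict(captions)

# Made-up lines in the assumed format (not real dataset entries)
sample = [
    "example_1.jpg#0\tA dog runs across the grass .",
    "example_1.jpg#1\tA brown dog is running outside .",
    "example_2.jpg#0\tTwo children play on a beach .",
]
grouped = parse_token_lines(sample)
print(grouped["example_1.jpg"])
```

With the real file, the same function could be fed the lines of an opened Flickr8k.token.txt, and each image id would be expected to map to 5 captions.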

1.2 Dataset description file

{"images": [ {"sentids": [ ], "imgid": 0, "sentences": [ {"tokens": [ ], "raw": "…", "imgid": 0, "sentid": 0}, … ], "split": "train", "filename": "…jpg"}, {"sentids": …} ], "dataset": "flickr8k"}

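"sentids": [0,1,2,3,4] - the ids of the captions (each image has 5 captions, so the sentids of image 0 run from 0 to 4)
"imgid": 0 - the id of the image (8000 images in total, numbered 0 to 7999)
"sentences": [ ] - the 5 captions of one image
"tokens": [ ] - a caption split into individual words
"raw": " " - the caption as a single string, i.e. the tokens joined together
"imgid": 0 - the id of the image that the caption belongs to
"sentid": 0 - the id of this specific caption of image 0
"split": " " - assigns the image and its captions to the training, validation, or test set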
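
To make the field descriptions concrete, here is a minimal sketch that walks over a tiny hand-written stand-in for dataset_flickr8k.json. The file names, captions, and counts in `sample_json` are made up for illustration; only the key names follow the structure described above.

```python
import json
from collections import Counter

# A tiny stand-in for dataset_flickr8k.json (contents are made up)
sample_json = """
{"images": [
  {"sentids": [0, 1], "imgid": 0, "split": "train", "filename": "a.jpg",
   "sentences": [
     {"tokens": ["a", "dog", "runs"], "raw": "A dog runs.", "imgid": 0, "sentid": 0},
     {"tokens": ["the", "dog", "plays"], "raw": "The dog plays.", "imgid": 0, "sentid": 1}]},
  {"sentids": [2], "imgid": 1, "split": "test", "filename": "b.jpg",
   "sentences": [
     {"tokens": ["a", "cat", "sleeps"], "raw": "A cat sleeps.", "imgid": 1, "sentid": 2}]}
 ], "dataset": "flickr8k"}
"""
data = json.loads(sample_json)

# Count how many images fall into each split
split_counts = Counter(img["split"] for img in data["images"])

# Print one line per caption: file name, split, caption id, tokenized caption
for img in data["images"]:
    for sent in img["sentences"]:
        print(img["filename"], img["split"], sent["sentid"], " ".join(sent["tokens"]))
print(dict(split_counts))
```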


1.3 Process Datasets


import json
import os
import random
from collections import Counter, defaultdict
from matplotlib import pyplot as plt
from PIL import Image
from argparse import Namespace
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms

def create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):
    """
    Process the raw dataset into a vocabulary file and per-split data files.

    Args:
        dataset: Name of the dataset
        captions_per_image: Number of captions per image
        min_word_count: Only consider words that appear at least this many times in the dataset (excluding the test set)
        max_len: Maximum number of words in a caption. Captions longer than this will be truncated.

    Outputs:
        A vocabulary file: vocab.json
        Three dataset files: train_data.json, val_data.json, test_data.json
    """
    # Paths for reading data and saving processed data
    # Path to the dataset JSON file
    flickr_json_path = ".../sd/data/dataset_flickr8k.json"
    # Folder containing images
    image_folder = ".../sd/data/Flicker8k_Dataset"
    # Folder to save processed results
    # The % operator is used to format the string by replacing %s with the value of the dataset variable.
    # For example, if dataset is "flickr8k", the resulting output_folder will be
    # /home/wxy/Documents/PycharmProjects/pytorch-stable-diffusion/sd/data/flickr8k.
    output_folder = ".../sd/data/%s" % dataset

    # Ensure output directory exists
    os.makedirs(output_folder, exist_ok=True)
    print(f"Output folder: {output_folder}")

    # Read the dataset JSON file
    with open(file=flickr_json_path, mode="r") as j:
        data = json.load(fp=j)
    # Initialize containers for image paths, captions, and vocabulary
    # Dictionary to store image paths
    image_paths = defaultdict(list)
    # Dictionary to store image captions
    image_captions = defaultdict(list)
    # Counter behaves like a dictionary that counts elements:
    # key: a word, value: the number of times that word occurs.
    vocab = Counter()
    # read from file dataset_flickr8k.json
    for img in data["images"]:  # Iterate over each image in the dataset
        split = img["split"]  # Determine the split (train, val, or test) for the image
        captions = []
        for c in img["sentences"]:  # Iterate over each caption for the image
            # Update word frequency count, excluding test set data
            if split != "test":  # Only update vocabulary for train/val splits
                # c['tokens'] is a list of words; each word's count is increased by one per occurrence
                vocab.update(c['tokens'])  # Update vocabulary with words in the caption
            # Only consider captions that are within the maximum length
            if len(c["tokens"]) <= max_len:
                captions.append(c["tokens"])  # Add the caption to the list if it meets the length requirement

        if len(captions) == 0:  # Skip images with no valid captions
            continue
        # Construct the full image path: image_folder + image file name,
        # e.g. .../sd/data/Flicker8k_Dataset/<img['filename']>
        path = os.path.join(image_folder, img['filename'])
        # Save the full image path and its captions in the respective dictionaries
        image_paths[split].append(path)
        image_captions[split].append(captions)
The code below is quoted from [Flickr8k数据集处理] (for learning purposes only). It does the following:

1. Iterates over the 5 captions for the image at index 250.
2. Retrieves the word indices for each caption.
3. Converts the word indices to words using vocab_idx2word.
4. Joins the words to form complete sentences.
5. Prints each caption.
import json
from PIL import Image
from matplotlib import pyplot as plt
# Load the vocabulary from the JSON file
with open('.../sd/data/flickr8k/vocab.json', 'r') as f:
    vocab = json.load(f)  # Load the vocabulary from the JSON file into a dictionary
# Create a dictionary to map indices to words
vocab_idx2word = {idx: word for word, idx in vocab.items()}
# Load the test data from the JSON file
with open('.../sd/data/flickr8k/test_data.json', 'r') as f:
    data = json.load(f)  # Load the test data from the JSON file into a dictionary
# Open and display the image at index 250 of the test set's 'IMAGES' list
content_img = Image.open(data['IMAGES'][250])
plt.figure(figsize=(6, 6))
plt.imshow(content_img)
plt.show()
# Print the lengths of the data, image list, and caption list
print(len(data))  # Number of keys in the dataset dictionary (should be 2: 'IMAGES' and 'CAPTIONS')
print(len(data['IMAGES']))  # Number of images in the 'IMAGES' list
print(len(data["CAPTIONS"]))  # Number of captions in the 'CAPTIONS' list
# Display the 5 captions associated with the image at index 250
for i in range(5):
    # Get the word indices for the i-th caption of this image
    word_indices = data['CAPTIONS'][250 * 5 + i]
    # Convert indices to words and join them with spaces to form a caption
    print(' '.join([vocab_idx2word[idx] for idx in word_indices]))
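
If the stored captions contain marker tokens, the printed sentences will include them. The sketch below shows one way to drop such tokens while decoding. The function name `decode_caption`, the toy vocabulary, and the presence of `<start>` and `<end>` tokens are assumptions; only `<pad>` is confirmed by the Dataset code in this document.

```python
def decode_caption(word_indices, idx2word, skip=("<pad>", "<start>", "<end>")):
    """Turn a list of word indices into a sentence, dropping marker tokens."""
    words = [idx2word[i] for i in word_indices]
    return " ".join(w for w in words if w not in skip)

# Toy vocabulary for illustration only
toy_vocab = {"<pad>": 0, "a": 1, "dog": 2, "runs": 3, "<start>": 4, "<end>": 5}
toy_idx2word = {idx: word for word, idx in toy_vocab.items()}

sentence = decode_caption([4, 1, 2, 3, 5, 0, 0], toy_idx2word)
print(sentence)
```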

1.4 Dataloader


import json
import os
import random
from collections import Counter, defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils import data
import torchvision.transforms as transforms

class ImageTextDataset(Dataset):
    """PyTorch Dataset class to generate data batches using torch DataLoader."""

    def __init__(self, dataset_path, vocab_path, split, captions_per_image=5, max_len=30, transform=None):
        """
        Args:
            dataset_path: Path to the JSON file containing the dataset
            vocab_path: Path to the JSON file containing the vocabulary
            split: The dataset split, which can be "train", "val", or "test"
            captions_per_image: Number of captions per image
            max_len: Maximum number of words per caption
            transform: Image transformation methods
        """
        self.split = split
        # Validate that the split is one of the allowed values
        assert self.split in {"train", "val", "test"}
        # Store captions per image
        self.cpi = captions_per_image
        # Store maximum caption length
        self.max_len = max_len

        # Load the dataset
        with open(dataset_path, "r") as f:
            self.data = json.load(f)

        # Load the vocabulary
        with open(vocab_path, "r") as f:
            self.vocab = json.load(f)

        # Store the image transformation methods
        self.transform = transform

        # Number of captions in the dataset
        # Calculate the size of the dataset
        self.dataset_size = len(self.data["CAPTIONS"])

    def __getitem__(self, i):
        """Retrieve the i-th sample from the dataset."""
        # The i-th sample uses the (i // self.cpi)-th image, since each image has self.cpi captions
        img = Image.open(self.data['IMAGES'][i // self.cpi]).convert("RGB")
        # Apply image transformation if provided
        if self.transform is not None:
            # Apply the transformation to the image
            img = self.transform(img)
        # Get the length of the caption
        caplen = len(self.data["CAPTIONS"][i])
        # Build the padding needed to bring the caption up to max_len + 2 tokens
        pad_caps = [self.vocab['<pad>']] * (self.max_len + 2 - caplen)
        # Append the padding and convert the caption to a tensor
        caption = torch.LongTensor(self.data["CAPTIONS"][i] + pad_caps)
        return img, caption, caplen  # Return the image, caption, and caption length

    def __len__(self):
        return self.dataset_size  # Number of samples in the dataset

def make_train_val(data_dir, vocab_path, batch_size, workers=4):
    """
    Create DataLoader objects for the training, validation, and test sets.

    Args:
        data_dir: Directory where the dataset JSON files are located
        vocab_path: Path to the vocabulary JSON file
        batch_size: Number of samples per batch
        workers: Number of subprocesses to use for data loading (default is 4)

    Returns:
        train_loader: DataLoader for the training set
        val_loader: DataLoader for the validation set
        test_loader: DataLoader for the test set
    """
    # Define transformation for training set
    train_tx = transforms.Compose([
        transforms.Resize((256, 256)),  # Resize images to 256x256 (Resize(256) would only scale the shorter side)
        transforms.ToTensor(),  # Convert image to PyTorch tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize using ImageNet mean and std
    ])

    # Define transformation for validation and test sets
    val_tx = transforms.Compose([
        transforms.Resize((256, 256)),  # Same fixed size, so images can be stacked into a batch
        transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Create dataset objects for training, validation, and test sets
    train_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "train_data.json"), vocab_path=vocab_path,
                                 split="train", transform=train_tx)

    valid_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "val_data.json"), vocab_path=vocab_path,
                                 split="val", transform=val_tx)

    test_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "test_data.json"), vocab_path=vocab_path,
                                split="test", transform=val_tx)
    # Create DataLoader for training set with data shuffling
    train_loader = data.DataLoader(
        dataset=train_set, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True
    )
    # Create DataLoader for validation set without data shuffling
    val_loader = data.DataLoader(
        dataset=valid_set, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True, drop_last=False
    )
    # Create DataLoader for test set without data shuffling
    test_loader = data.DataLoader(
        dataset=test_set, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True, drop_last=False
    )

    return train_loader, val_loader, test_loader
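
To make the indexing and padding inside `__getitem__` concrete, here is a small pure-Python sketch that needs no PyTorch. The helper names `image_index` and `pad_caption` are hypothetical and not part of the original code, and the pad id 0 and the encoded caption are made-up values for illustration.

```python
def image_index(sample_index, captions_per_image=5):
    """Each image owns captions_per_image consecutive samples."""
    return sample_index // captions_per_image

def pad_caption(encoded, pad_id, max_len=30):
    """Right-pad an encoded caption to max_len + 2 entries."""
    target_len = max_len + 2
    return encoded + [pad_id] * (target_len - len(encoded))

# Samples 0..4 map to image 0, samples 5..9 to image 1, and so on
mapping = [image_index(i) for i in range(12)]
print(mapping)

# A made-up encoded caption of length 5, padded for max_len=6
padded = pad_caption([4, 1, 2, 3, 5], pad_id=0, max_len=6)
print(padded, len(padded))
```

Because every caption is padded to the same length, the default batching of a DataLoader can stack them into one tensor, while `caplen` keeps track of how many entries are real.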


1.5 Training, Validation, and Test Sets



Training Dataset: The sample of data used to fit the model.

This is the actual dataset used to train the model (the weights and biases, in the case of a neural network). The model sees and learns from this data.


Validation Dataset: The sample of data used to provide an unbiased evaluation of a model fit on the training dataset while tuning model hyperparameters. The evaluation becomes more biased as skill on the validation dataset is incorporated into the model configuration.

The validation set is used to evaluate a given model, and it is used frequently during development. As machine learning engineers, we use this data to tune the model's hyperparameters. The model therefore sees this data occasionally but never "learns" from it directly: we look at the validation results and adjust higher-level hyperparameters accordingly. So the validation set affects the model, but only indirectly. The validation set is also known as the dev set or development set, which makes sense because it helps during the "development" stage of the model.


Test Dataset: The sample of data used to provide an unbiased evaluation of a final model fit on the training dataset.

The test dataset provides the gold standard used to evaluate the model. It is used only once a model is completely trained (using the training and validation sets). The test set is generally what is used to evaluate competing models. For example, in many Kaggle competitions the validation set is released initially along with the training set, the actual test set is released only when the competition is about to close, and the model's result on the test set decides the winner. The validation set is often used as the test set, but this is not good practice. The test set is generally well curated: it contains carefully sampled data spanning the various classes the model would face in the real world.

Now that you know what these datasets do, you might be looking for recommendations on how to split your dataset into Train, Validation and Test sets.

This mainly depends on 2 things. First, the total number of samples in your data and second, on the actual model you are training.

Some models need substantial data to train on, so in this case you would optimize for a larger training set. Models with very few hyperparameters are easy to validate and tune, so you can probably reduce the size of your validation set; if your model has many hyperparameters, you will want a large validation set as well (although you should also consider cross-validation). And if your model has no hyperparameters, or none that can easily be tuned, you probably don't need a validation set at all.
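
As an illustration of the cross-validation idea mentioned above, the following minimal sketch produces k train/validation index splits so that every sample is used for validation exactly once. The function name `k_fold_indices` and the choice of contiguous, unshuffled folds are assumptions made for simplicity.

```python
def k_fold_indices(n_samples, k=5):
    """Return k (train_indices, val_indices) pairs covering all samples."""
    # Spread the remainder over the first folds so sizes differ by at most one
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        folds.append((train, val))
        start += size
    return folds

folds = k_fold_indices(10, k=3)
for train, val in folds:
    print(len(train), val)
```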

All in all, like many other things in machine learning, the train/validation/test split ratio is quite specific to your use case, and the judgment gets easier as you train and build more models.
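
For datasets that do not ship with a predefined split, a simple random split can be sketched as below. The function name `split_indices`, the 80/10/10 ratio, and the fixed seed are illustrative assumptions, not recommendations from the text.

```python
import random

def split_indices(n_samples, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle sample indices and cut them into train, val, and test lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_val = round(n_samples * val_frac)
    n_test = round(n_samples * test_frac)
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]
    return train, val, test

train_idx, val_idx, test_idx = split_indices(8000)
print(len(train_idx), len(val_idx), len(test_idx))
```

Keeping the seed fixed means the same samples land in the test set on every run, which avoids accidentally tuning on data that is later used for the final evaluation.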

