Hi!
My name is Zhengyan Lambo, and I am a third-year Computer Engineering student at HKUST. Take a look at some of the projects and fun stuff I have done! Use the "Labels" feature on the left to browse projects by category.
___________________________________________________________________________________________________________________
Hello! My name is Zhengyan Lambo, and I am a third-year Computer Engineering student at HKUST. Take a look at some of the projects and fun things I have made! Use the "Labels" feature on the left to view projects by category.
Image classification with CNN
What is a Convolutional Neural Network (CNN)?
A CNN is a type of supervised deep-learning model that uses convolutional feature extraction together with a densely connected network for classification tasks. CNNs are most popular for image classification, but they can also be used for audio, signal, and time-series classification.
How does a CNN work?
A CNN is composed of a feature-extraction part and a classification part.
Features are extracted with matrix operations: a filter matrix slides across the input. Where the underlying data closely matches the filter's pattern, the output is a large value; where it does not, the output is small. At each position, every filter value is multiplied by the corresponding input value and the products are summed (a dot product). A ReLU function is then applied to make the output non-linear.
There are options for feature extraction such as padding, stride, and filter window size. Padding adds an empty margin around the input so that edge and corner values are not overlooked. Stride defines how far the filter window moves at each step, and the filter window size defines the dimensions of the filter.
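To make this concrete, here is a minimal NumPy sketch of the sliding-filter operation described above (the function name and the zero-padding choice are my own, not from any library):

```python
import numpy as np

def conv2d(image, kernel, stride=1, padding=0):
    """Slide `kernel` over `image`, taking a dot product at each position."""
    if padding:
        image = np.pad(image, padding)  # zero margin so corner values are covered
    kh, kw = kernel.shape
    ih, iw = image.shape
    oh = (ih - kh) // stride + 1        # output height
    ow = (iw - kw) // stride + 1        # output width
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            patch = image[i*stride:i*stride + kh, j*stride:j*stride + kw]
            out[i, j] = np.sum(patch * kernel)  # dot product of patch and filter
    return np.maximum(out, 0)            # ReLU non-linearity

# A 3x3 all-ones filter over a 5x5 all-ones image: every dot product is 9
result = conv2d(np.ones((5, 5)), np.ones((3, 3)))
```

With stride 2 and padding 1 the same inputs yield a 3x3 output, showing how these options trade spatial resolution for coverage.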
Pooling is often applied between feature-extraction steps. Pooling partitions the data into sub-regions with a given window size, then represents each region by a single value. Max pooling, the most commonly used type, takes the maximum value of each region; average pooling takes the mean.
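A small NumPy sketch of both pooling variants (function name and defaults are my own):

```python
import numpy as np

def pool2d(x, window=2, stride=2, mode="max"):
    """Summarize each window-sized region by its max (or mean)."""
    oh = (x.shape[0] - window) // stride + 1
    ow = (x.shape[1] - window) // stride + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            region = x[i*stride:i*stride + window, j*stride:j*stride + window]
            out[i, j] = region.max() if mode == "max" else region.mean()
    return out

x = np.arange(16).reshape(4, 4)
maxed = pool2d(x)                   # [[5, 7], [13, 15]]
averaged = pool2d(x, mode="average")  # [[2.5, 4.5], [10.5, 12.5]]
```

Note how a 4x4 input shrinks to 2x2: this size reduction is exactly the benefit discussed below.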
After the final pooling step, the feature-representative data is flattened and passed to a densely connected neural network, which is just a basic ANN.
Why do we use CNNs?
There are a few benefits to using a CNN instead of a plain Artificial Neural Network (ANN), namely: 1) reduced training size, 2) reduced overfitting, and 3) tolerance to distortion in the data. However, these benefits require pooling; the value of convolution filtering seems to lie in enabling the pooling process rather than in any inherent data processing.
Reduced training size is achieved through pooling: since pooling extracts representative feature information from blocks of the input, it shrinks the data. A larger pooling window and stride result in a greater reduction in size.
Pooling also represents the original data in terms of the locations of key features, reducing the density and noise of the data before it is passed to the dense layer. This makes the data more "sparse", which reduces overfitting and makes the model more tolerant of data distortion.
Data Augmentation
A CNN by itself does not handle distorted or rotated samples well; you need to provide distorted/rotated training samples so that it can account for them. Data augmentation applies random distortions and rotations to training samples. In TensorFlow, data augmentation can be defined as a sequential pipeline.
Then, it can be applied as one step in the modeling training pipeline.
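In TensorFlow this might look like the following (the layer choices and parameters here are illustrative, not any tutorial's exact values):

```python
import tensorflow as tf

# Augmentation defined as a sequential pipeline of random transforms
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),  # rotate up to ±20% of a full turn
])

# These layers are only active when called with training=True;
# at inference time they pass images through unchanged.
images = tf.random.uniform((4, 256, 256, 3))
augmented = data_augmentation(images, training=True)
```

Because the pipeline is itself a Keras layer, it can be dropped into the model as one step of training.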
Transfer Learning
Transfer learning means training a custom model on top of an already-trained model. It can significantly reduce training cost and time, so it is a handy technique if you are training a model with many parameters.
To do transfer learning, you load the pre-trained model without its last output layer; this headless model is called a feature vector. TensorFlow hosts pre-trained models at TensorFlow Hub (tfhub.dev). Then, you set trainable to False so the pre-trained weights are not changed. Lastly, you create your own model by adding new layers on top of the feature vector.
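As a sketch, here is the same idea using a Keras application model in place of a TF Hub feature vector (I use weights=None below only so the example runs offline; for real transfer learning you would pass weights="imagenet" to get the pre-trained weights):

```python
import tensorflow as tf

# Load a model without its final output layer (include_top=False)
base = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,   # drop the classification head
    weights=None,        # use weights="imagenet" in practice
    pooling="avg",       # global average pooling -> a feature vector
)
base.trainable = False   # freeze the pre-trained weights

# Add new layers on top for our own classes
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(3, activation="softmax"),  # e.g. 3 classes
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```

Only the small new head is trained, which is why convergence is so much faster.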
In my case, the transfer-learning model reached 90% accuracy in only 4 epochs, whereas training a CNN from scratch took over 20 epochs.
OK, now it is time for a project. I am following this end-to-end tutorial from codebasics. The project is about making a mobile app that takes a photo of a potato plant and tells you whether it has blight disease and how severe it is. The tutorial covers these topics:
TensorFlow input pipelines (tf.data.Dataset)
CNN
Data augmentation
TF Lite, quantization (to reduce model size)
Google Cloud Functions (GCP)
React.js and React Native
TF Serving, FastAPI
Training the model
Data processing
Step 1: Load the images into a tf.data.Dataset.
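A sketch of this step; the directory layout, class names, and image size below are stand-ins (the tutorial loads its own PlantVillage folder at 256x256), and I generate dummy images so the example is self-contained:

```python
import pathlib
import tempfile
import numpy as np
import tensorflow as tf

# Build a tiny dummy "root/class_name/*.png" tree to stand in for the real dataset
root = pathlib.Path(tempfile.mkdtemp())
for label in ["early_blight", "healthy"]:
    (root / label).mkdir()
    for i in range(3):
        img = np.random.randint(0, 255, (64, 64, 3)).astype("uint8")
        tf.keras.utils.save_img(str(root / label / f"{i}.png"), img)

# Load images into a tf.data.Dataset, with labels inferred from folder names
dataset = tf.keras.utils.image_dataset_from_directory(
    root,
    image_size=(64, 64),  # the real project would use (256, 256)
    batch_size=2,
)
```

The class names are picked up automatically from the folder names, in alphabetical order.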
Step 2: Split the dataset into 80% train, 10% validation, and 10% test.
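tf.data has no built-in three-way split, so a common approach (shown here on a dummy numeric dataset; the helper name and seed are my own) is to shuffle once and then use take/skip:

```python
import tensorflow as tf

def split_dataset(ds, train=0.8, val=0.1, shuffle_size=1000):
    """Split a tf.data.Dataset into train/validation/test parts."""
    n = ds.cardinality().numpy()
    ds = ds.shuffle(shuffle_size, seed=12)  # shuffle once, with a fixed seed
    train_n = int(n * train)
    val_n = int(n * val)
    train_ds = ds.take(train_n)
    val_ds = ds.skip(train_n).take(val_n)
    test_ds = ds.skip(train_n + val_n)      # the remaining ~10%
    return train_ds, val_ds, test_ds

ds = tf.data.Dataset.range(100)
train_ds, val_ds, test_ds = split_dataset(ds)  # 80 / 10 / 10 elements
```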
Step 3: Cache and prefetch data.
This makes training faster: cache keeps the training data in memory after it has been read from disk once, while prefetch lets the CPU load the next files while the GPU is training on the current batch.
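Roughly like this (AUTOTUNE lets TensorFlow pick the buffer size; the shuffle size is arbitrary here):

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

ds = tf.data.Dataset.range(10)   # stand-in for the image dataset
ds = (ds
      .cache()                   # keep elements in memory after the first pass
      .shuffle(1000)             # reshuffle each epoch
      .prefetch(AUTOTUNE))       # prepare next elements while the GPU trains
```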
Step 4: Preprocess the data: resize, rescale, and data augmentation.
This makes the model more tolerant of images of different "personalities".
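The resize and rescale steps can themselves be Keras layers, so they travel with the model. A sketch (256x256 matches typical image-classification setups; the 1/255 scale assumes 8-bit images):

```python
import tensorflow as tf

resize_and_rescale = tf.keras.Sequential([
    tf.keras.layers.Resizing(256, 256),    # force a uniform image size
    tf.keras.layers.Rescaling(1.0 / 255),  # map pixel values into [0, 1]
])

# Works on a batch of images of any incoming size
batch = tf.random.uniform((2, 300, 200, 3), maxval=255)
out = resize_and_rescale(batch)
```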
Step 5: Build, compile, and fit the model.
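A minimal sketch of this step: a much smaller network than a real one, fitted on random tensors just to show the API (in the real project, fit() receives the train and validation datasets built above):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(64, 64, 3)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),   # feature extraction
    tf.keras.layers.MaxPooling2D(),                    # pooling
    tf.keras.layers.Conv2D(16, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),                         # to a 1-D vector
    tf.keras.layers.Dense(32, activation="relu"),      # dense classifier
    tf.keras.layers.Dense(3, activation="softmax"),    # 3 output classes
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Dummy data just to exercise fit()
x = tf.random.uniform((8, 64, 64, 3))
y = tf.random.uniform((8,), maxval=3, dtype=tf.int32)
history = model.fit(x, y, epochs=1, verbose=0)
```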
Visualizing model performance
Create a prediction function. The function takes in a preprocessed image (e.g. from the test dataset) and turns it into a NumPy array. Then we add one dimension to the image array so that it becomes a batch of size 1, which is the format the model accepts. The model returns a softmax probability distribution, from which we take the class with the maximum confidence as the predicted class.
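A sketch of such a function; the class names and the tiny untrained stand-in model are assumptions for illustration:

```python
import numpy as np
import tensorflow as tf

class_names = ["Early Blight", "Late Blight", "Healthy"]  # assumed label order

def predict(model, image):
    """Return (predicted class, confidence) for a single image."""
    img_array = tf.keras.utils.img_to_array(image)  # image -> NumPy array
    img_batch = np.expand_dims(img_array, axis=0)   # add a batch dim: (1, H, W, C)
    probs = model.predict(img_batch, verbose=0)[0]  # softmax distribution
    return class_names[np.argmax(probs)], float(np.max(probs))

# Untrained stand-in model just to demonstrate the call
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(64, 64, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3, activation="softmax"),
])
label, confidence = predict(model, np.random.rand(64, 64, 3))
```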
Finally, I can visualize a set of predictions with the plt library (very handy!) to see how the model performs. As you can see below, the model is doing amazingly well, making only 1 mistake. To be honest, I can't even tell early and late blight apart!
Saving the model
Lastly, we save the model to disk.
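In recent Keras versions this is a single call (the filename and the tiny stand-in model are my own; a real project would save the trained model):

```python
import os
import tempfile
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(64, 64, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3, activation="softmax"),
])

path = os.path.join(tempfile.mkdtemp(), "potato_model.keras")
model.save(path)                               # saves architecture + weights
restored = tf.keras.models.load_model(path)    # load it back for serving
```

The restored model produces the same predictions as the original, which is a quick sanity check after saving.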
Making a REST API and a web interface
To be honest, I won't spend too much time learning to write a website. I just followed the instructions on setting up the React environment and ran the website. Here are the results.
However, I recently discovered a great Python module called Streamlit, which lets people create interactive web tools super easily. Since I am trying to learn deep learning and AI, and the projects I want to make will require some kind of interface, Streamlit is a handy skill to learn. So, I will remake this website in Streamlit in the future, and go in-depth into how this potato disease classification interface is made.
Conclusion
In this project, I learned about TensorFlow data pipelines and data augmentation. I experienced the basic workflow of turning image sets into a working model. I also learned some nice "plt" methods for plotting batches of images along with prediction information. Furthermore, I learned how to run a web tool on my local machine.
I tried the model on images of potato blight found online, and it did not do very well: it classified almost everything as early blight. However, the model seems to perform better on images with a white background as opposed to a natural background filled with vegetation. There may be a couple of reasons:
The number of samples is too small(~2200 total)
Imbalanced dataset: 1000 early blight, 1000 late blight, 150 healthy
Not enough data augmentation
Model overfitting
This will be the end of this article for now. As of today, I am going to explore LLM APIs (LangChain) and web scraping. These skills will be needed to create some interesting projects.