发布于2019-08-20 18:49 阅读(2228) 评论(0) 点赞(7) 收藏(0)
这个ctc_loss很魔性,训练CRNN虐了我几个来回。
我的数据集图片大小不一,我是先等比例缩小到固定高度为32,宽度不定。
常见三个问题:
1.CTC Loss Error: invalidArgumentError: Not Enough time for target transition sequence.
2.CTC Loss Error: InvalidArgumentError: sequence_length(b) <= time
3.ctc_loss error “No valid path found.” (这个错误对模型收敛没有很大影响,只是出错的那一个batch参数没有更新优化。如果这个错误很少,可以忽略。如果这个错误很多的话就建议用下面方法优化一下训练集。)
导致这三个问题的原因,就是label_length 和input_length的取值问题。
1. CRNN一个主要优点就是可以识别任意长度的图片。在训练的时候,先统一将图片padding到一个固定的很长的宽度。然后input_length设置为你等比例缩小后,padding之前的图片的宽除以四。部分代码如下:
- Img = Image.open(imagepath).convert('L')
- ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
- input_length[i] = ResizedImg.shape[1] // 4
2. label_length很简单理解,就是ground truth的长度。
3. 如果你以为这样就完事大吉可以训练你就错了。因为你的图片可能有不合格的存在。导致问题3出现,loss变为inf。
4. 所以在训练前,应该过滤一遍所有训练集和验证集的图片。ctc_loss在计算预测结果和真值的loss的时候,会在你真值label中重复的字符之间插入空符,所以必须将label_length加上空符个数大于input_length的图片删除掉。而代码中的2,是我考虑有可能在label的开头和末尾存在空符。(我并没有验证这个想法,只是为了保险起见。)举个例子,你图片高度为32,宽度为160,那么input_length=40。label='abbbccddddcccaa',label_length=15,经过计算repreat_number为2(bbb)+1(cc)+3(dddd)+2(ccc)+1(aa),然后再加上开头结果的空符数2,最终等于11。也就是说必须满足label_length(15)+repreat_number(11)<=input_length(40)的图片才是合格的图片。部分代码如下:
- Img = np.array(Image.open(ImgRootPath + '/' + imgName).convert('L'))
- ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
- l = [len(list(g)) for k, g in itertools.groupby(Label)]
- repeat_number = 0
- for n in l:
- if n > 1:
- repeat_number += (n - 1)
- input_length = ResizedImg.shape[1] // 4
- if len(Label)+repeat_number+2 > input_length:
- continue
作者:iuie9493
链接:https://www.pythonheidong.com/blog/article/49549/91023bde25f65df6c4bb/
来源:python黑洞网
任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任
昵称:
评论内容:(最多支持255个字符)
---无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事,而不是让内心的烦躁、焦虑,坏掉你本来就不多的热情和定力
Copyright © 2018-2021 python黑洞网 All Rights Reserved 版权所有,并保留所有权利。 京ICP备18063182号-1
投诉与举报,广告合作请联系vgs_info@163.com或QQ3083709327
免责声明:网站文章均由用户上传,仅供读者学习交流使用,禁止用做商业用途。若文章涉及色情,反动,侵权等违法信息,请向我们举报,一经核实我们会立即删除!